Merge branch 'v18.x-support' into 'main'

v18.x support / NLS 3.4.x compatibility

See merge request oscar.krause/fastapi-dls!46
This commit is contained in:
Oscar Krause 2025-04-11 22:22:27 +02:00
commit d318027f4e
3 changed files with 337 additions and 23 deletions

View File

@ -6,7 +6,7 @@ from datetime import datetime, timedelta, UTC
from hashlib import sha256
from json import loads as json_loads
from os import getenv as env
from os.path import join, dirname
from os.path import join, dirname, isfile
from uuid import uuid4
from dateutil.relativedelta import relativedelta
@ -21,7 +21,7 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
from orm import Origin, Lease, init as db_init, migrate
from util import PrivateKey, PublicKey, load_file
from util import PrivateKey, PublicKey, load_file, Cert
# Load variables
load_dotenv('../version.env')
@ -50,6 +50,7 @@ LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
DT_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
@ -248,6 +249,7 @@ async def _client_token():
"iat": timegm(cur_time.timetuple()),
"nbf": timegm(cur_time.timetuple()),
"exp": timegm(exp_time.timetuple()),
"protocol_version": "2.0",
"update_mode": "ABSOLUTE",
"scope_ref_list": [ALLOTMENT_REF],
"fulfillment_class_ref_list": [],
@ -298,14 +300,19 @@ async def auth_v1_origin(request: Request):
Origin.create_or_update(db, data)
environment = {
'raw_env': j.get('environment')
}
environment.update(j.get('environment'))
response = {
"origin_ref": origin_ref,
"environment": j.get('environment'),
"environment": environment,
"svc_port_set_list": None,
"node_url_list": None,
"node_query_order": None,
"prompts": None,
"sync_timestamp": cur_time.isoformat()
"sync_timestamp": cur_time.strftime(DT_FORMAT)
}
return JSONr(response)
@ -331,7 +338,7 @@ async def auth_v1_origin_update(request: Request):
response = {
"environment": j.get('environment'),
"prompts": None,
"sync_timestamp": cur_time.isoformat()
"sync_timestamp": cur_time.strftime(DT_FORMAT)
}
return JSONr(response)
@ -362,7 +369,7 @@ async def auth_v1_code(request: Request):
response = {
"auth_code": auth_code,
"sync_timestamp": cur_time.isoformat(),
"sync_timestamp": cur_time.strftime(DT_FORMAT),
"prompts": None
}
@ -404,19 +411,266 @@ async def auth_v1_token(request: Request):
auth_token = jwt.encode(new_payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm=ALGORITHMS.RS256)
response = {
"expires": access_expires_on.isoformat(),
"expires": access_expires_on.strftime(DT_FORMAT),
"auth_token": auth_token,
"sync_timestamp": cur_time.isoformat(),
"sync_timestamp": cur_time.strftime(DT_FORMAT),
"prompts": None
}
return JSONr(response)
# NLS 3.4.0 - venv/lib/python3.12/site-packages/nls_services_lease/test/test_lease_single_controller.py
@app.post('/leasing/v1/config-token', description='request to get config token for lease operations')
async def leasing_v1_config_token(request: Request):
j, cur_time = json_loads((await request.body()).decode('utf-8')), datetime.now(UTC)
logger.debug(f'CALLED /leasing/v1/config-token')
logger.debug(f'Headers: {request.headers}')
logger.debug(f'Request: {j}')
# todo: THIS IS A DEMO ONLY
###
#
# https://git.collinwebdesigns.de/nvidia/nls/-/blob/main/src/test/test_config_token.py
#
###
root_private_key_filename = join(dirname(__file__), 'cert/my_demo_root_private_key.pem')
root_certificate_filename = join(dirname(__file__), 'cert/my_demo_root_certificate.pem')
ca_private_key_filename = join(dirname(__file__), 'cert/my_demo_ca_private_key.pem')
ca_certificate_filename = join(dirname(__file__), 'cert/my_demo_ca_certificate.pem')
si_private_key_filename = join(dirname(__file__), 'cert/my_demo_si_private_key.pem')
si_certificate_filename = join(dirname(__file__), 'cert/my_demo_si_certificate.pem')
def init_config_token_demo():
from cryptography import x509
from cryptography.hazmat._oid import NameOID
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key
from cryptography.hazmat.primitives.serialization import Encoding
""" Create Root Key and Certificate """
# create root keypair
my_root_private_key = generate_private_key(public_exponent=65537, key_size=4096)
my_root_public_key = my_root_private_key.public_key()
# create root-certificate subject
my_root_subject = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, u'US'),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u'California'),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u'Nvidia'),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, u'Nvidia Licensing Service (NLS)'),
x509.NameAttribute(NameOID.COMMON_NAME, u'NLS Root CA'),
])
# create self-signed root-certificate
my_root_certificate = (
x509.CertificateBuilder()
.subject_name(my_root_subject)
.issuer_name(my_root_subject)
.public_key(my_root_public_key)
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(tz=UTC) - timedelta(days=1))
.not_valid_after(datetime.now(tz=UTC) + timedelta(days=365 * 10))
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
.add_extension(x509.SubjectKeyIdentifier.from_public_key(my_root_public_key), critical=False)
.sign(my_root_private_key, hashes.SHA256()))
my_root_private_key_as_pem = my_root_private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
with open(root_private_key_filename, 'wb') as f:
f.write(my_root_private_key_as_pem)
with open(root_certificate_filename, 'wb') as f:
f.write(my_root_certificate.public_bytes(encoding=Encoding.PEM))
""" Create CA (Intermediate) Key and Certificate """
# create ca keypair
my_ca_private_key = generate_private_key(public_exponent=65537, key_size=4096)
my_ca_public_key = my_ca_private_key.public_key()
# create ca-certificate subject
my_ca_subject = x509.Name([
x509.NameAttribute(NameOID.COUNTRY_NAME, u'US'),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u'California'),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u'Nvidia'),
x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, u'Nvidia Licensing Service (NLS)'),
x509.NameAttribute(NameOID.COMMON_NAME, u'NLS Intermediate CA'),
])
# create self-signed ca-certificate
my_ca_certificate = (
x509.CertificateBuilder()
.subject_name(my_ca_subject)
.issuer_name(my_root_subject)
.public_key(my_ca_public_key)
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(tz=UTC) - timedelta(days=1))
.not_valid_after(datetime.now(tz=UTC) + timedelta(days=365 * 10))
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
.add_extension(x509.KeyUsage(digital_signature=False, key_encipherment=False, key_cert_sign=True,
key_agreement=False, content_commitment=False, data_encipherment=False,
crl_sign=True, encipher_only=False, decipher_only=False), critical=True)
.add_extension(x509.SubjectKeyIdentifier.from_public_key(my_ca_public_key), critical=False)
# .add_extension(x509.AuthorityKeyIdentifier.from_issuer_public_key(my_root_public_key), critical=False)
.add_extension(x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
my_root_certificate.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
), critical=False)
.sign(my_root_private_key, hashes.SHA256()))
my_ca_private_key_as_pem = my_ca_private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
with open(ca_private_key_filename, 'wb') as f:
f.write(my_ca_private_key_as_pem)
with open(ca_certificate_filename, 'wb') as f:
f.write(my_ca_certificate.public_bytes(encoding=Encoding.PEM))
""" Create Service-Instance Key and Certificate """
# create si keypair
my_si_private_key = generate_private_key(public_exponent=65537, key_size=2048)
my_si_public_key = my_si_private_key.public_key()
my_si_private_key_as_pem = my_si_private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
my_si_public_key_as_pem = my_si_public_key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
with open(si_private_key_filename, 'wb') as f:
f.write(my_si_private_key_as_pem)
# with open('instance.public.pem', 'wb') as f:
# f.write(my_si_public_key_as_pem)
# create si-certificate subject
my_si_subject = x509.Name([
# x509.NameAttribute(NameOID.COMMON_NAME, INSTANCE_REF),
x509.NameAttribute(NameOID.COMMON_NAME, j.get('service_instance_ref')),
])
# create self-signed si-certificate
my_si_certificate = (
x509.CertificateBuilder()
.subject_name(my_si_subject)
.issuer_name(my_ca_subject)
.public_key(my_si_public_key)
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.now(tz=UTC) - timedelta(days=1))
.not_valid_after(datetime.now(tz=UTC) + timedelta(days=365 * 10))
.add_extension(x509.KeyUsage(digital_signature=True, key_encipherment=True, key_cert_sign=False,
key_agreement=True, content_commitment=False, data_encipherment=False,
crl_sign=False, encipher_only=False, decipher_only=False), critical=True)
.add_extension(x509.ExtendedKeyUsage([
x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH]
), critical=False)
.add_extension(x509.SubjectKeyIdentifier.from_public_key(my_si_public_key), critical=False)
# .add_extension(x509.AuthorityKeyIdentifier.from_issuer_public_key(my_ca_public_key), critical=False)
.add_extension(x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
my_ca_certificate.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
), critical=False)
.add_extension(x509.SubjectAlternativeName([
# x509.DNSName(INSTANCE_REF)
x509.DNSName(j.get('service_instance_ref'))
]), critical=False)
.sign(my_ca_private_key, hashes.SHA256()))
my_si_public_key_exp = my_si_certificate.public_key().public_numbers().e
my_si_public_key_mod = f'{my_si_certificate.public_key().public_numbers().n:x}' # hex value without "0x" prefix
with open(si_certificate_filename, 'wb') as f:
f.write(my_si_certificate.public_bytes(encoding=Encoding.PEM))
if not (isfile(root_private_key_filename)
and isfile(ca_private_key_filename)
and isfile(ca_certificate_filename)
and isfile(si_private_key_filename)
and isfile(si_certificate_filename)):
init_config_token_demo()
my_ca_certificate = Cert.from_file(ca_certificate_filename)
my_si_certificate = Cert.from_file(si_certificate_filename)
my_si_private_key = PrivateKey.from_file(si_private_key_filename)
my_si_private_key_as_pem = my_si_private_key.pem()
my_si_public_key = my_si_private_key.public_key().raw()
my_si_public_key_as_pem = my_si_private_key.public_key().pem()
""" build out payload """
cur_time = datetime.now(UTC)
exp_time = cur_time + CLIENT_TOKEN_EXPIRE_DELTA
payload = {
"iss": "NLS Service Instance",
"aud": "NLS Licensed Client",
"iat": timegm(cur_time.timetuple()),
"nbf": timegm(cur_time.timetuple()),
"exp": timegm(exp_time.timetuple()),
"protocol_version": "2.0",
"d_name": "DLS",
"service_instance_ref": j.get('service_instance_ref'),
"service_instance_public_key_configuration": {
"service_instance_public_key_me": {
"mod": hex(my_si_public_key.public_numbers().n)[2:],
"exp": int(my_si_public_key.public_numbers().e),
},
# 64 chars per line (pem default)
"service_instance_public_key_pem": my_si_public_key_as_pem.decode('utf-8').strip(),
"key_retention_mode": "LATEST_ONLY"
},
}
my_jwt_encode_key = jwk.construct(my_si_private_key_as_pem.decode('utf-8'), algorithm=ALGORITHMS.RS256)
config_token = jws.sign(payload, key=my_jwt_encode_key, headers=None, algorithm=ALGORITHMS.RS256)
response_ca_chain = my_ca_certificate.pem().decode('utf-8')
response_si_certificate = my_si_certificate.pem().decode('utf-8')
response = {
"certificateConfiguration": {
# 76 chars per line
"caChain": [response_ca_chain],
# 76 chars per line
"publicCert": response_si_certificate,
"publicKey": {
"exp": int(my_si_certificate.raw().public_key().public_numbers().e),
"mod": [hex(my_si_certificate.raw().public_key().public_numbers().n)[2:]],
},
},
"configToken": config_token,
}
logging.debug(response)
return JSONr(response, status_code=200)
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@app.post('/leasing/v1/lessor', description='request multiple leases (borrow) for current origin')
async def leasing_v1_lessor(request: Request):
j, token, cur_time = json_loads((await request.body()).decode('utf-8')), __get_token(request), datetime.now(UTC)
logger.debug(j)
logger.debug(request.headers)
try:
token = __get_token(request)
except JWTError:
@ -427,6 +681,7 @@ async def leasing_v1_lessor(request: Request):
logger.info(f'> [ create ]: {origin_ref}: create leases for scope_ref_list {scope_ref_list}')
lease_result_list = []
# todo: for lease_proposal in lease_proposal_list
for scope_ref in scope_ref_list:
# if scope_ref not in [ALLOTMENT_REF]:
# return JSONr(status_code=500, detail=f'no service instances found for scopes: ["{scope_ref}"]')
@ -434,15 +689,20 @@ async def leasing_v1_lessor(request: Request):
lease_ref = str(uuid4())
expires = cur_time + LEASE_EXPIRE_DELTA
lease_result_list.append({
"ordinal": 0,
"ordinal": None,
"error": None,
# https://docs.nvidia.com/license-system/latest/nvidia-license-system-user-guide/index.html
"lease": {
"ref": lease_ref,
"created": cur_time.isoformat(),
"expires": expires.isoformat(),
"created": cur_time.strftime(DT_FORMAT),
"expires": expires.strftime(DT_FORMAT),
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
"offline_lease": "true",
"license_type": "CONCURRENT_COUNTED_SINGLE"
"offline_lease": "false", # todo
"license_type": "CONCURRENT_COUNTED_SINGLE",
"lease_intent_id": None,
"metadata": None,
"feature_name": "GRID-Virtual-WS", # todo
"product_name": "NVIDIA RTX Virtual Workstation", # todo
}
})
@ -450,13 +710,16 @@ async def leasing_v1_lessor(request: Request):
Lease.create_or_update(db, data)
response = {
"client_challenge": None,
"lease_result_list": lease_result_list,
"result_code": "SUCCESS",
"sync_timestamp": cur_time.isoformat(),
"result_code": None,
"sync_timestamp": cur_time.strftime(DT_FORMAT),
"prompts": None
}
return JSONr(response)
logger.debug(response)
return JSONr(response, headers={'X-NLS-Signature': '?'})
# venv/lib/python3.9/site-packages/nls_services_lease/test/test_lease_multi_controller.py
@ -472,7 +735,7 @@ async def leasing_v1_lessor_lease(request: Request):
response = {
"active_lease_list": active_lease_list,
"sync_timestamp": cur_time.isoformat(),
"sync_timestamp": cur_time.strftime(DT_FORMAT),
"prompts": None
}
@ -495,11 +758,11 @@ async def leasing_v1_lease_renew(request: Request, lease_ref: str):
expires = cur_time + LEASE_EXPIRE_DELTA
response = {
"lease_ref": lease_ref,
"expires": expires.isoformat(),
"expires": expires.strftime(DT_FORMAT),
"recommended_lease_renewal": LEASE_RENEWAL_PERIOD,
"offline_lease": True,
"prompts": None,
"sync_timestamp": cur_time.isoformat(),
"sync_timestamp": cur_time.strftime(DT_FORMAT),
}
Lease.renew(db, entity, expires, cur_time)
@ -527,7 +790,7 @@ async def leasing_v1_lease_delete(request: Request, lease_ref: str):
response = {
"lease_ref": lease_ref,
"prompts": None,
"sync_timestamp": cur_time.isoformat(),
"sync_timestamp": cur_time.strftime(DT_FORMAT),
}
return JSONr(response)
@ -547,7 +810,7 @@ async def leasing_v1_lessor_lease_remove(request: Request):
response = {
"released_lease_list": released_lease_list,
"release_failure_list": None,
"sync_timestamp": cur_time.isoformat(),
"sync_timestamp": cur_time.strftime(DT_FORMAT),
"prompts": None
}
@ -569,7 +832,7 @@ async def leasing_v1_lessor_shutdown(request: Request):
response = {
"released_lease_list": released_lease_list,
"release_failure_list": None,
"sync_timestamp": cur_time.isoformat(),
"sync_timestamp": cur_time.strftime(DT_FORMAT),
"prompts": None
}

View File

@ -3,6 +3,7 @@ import logging
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey, generate_private_key
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
from cryptography.x509 import load_pem_x509_certificate, Certificate
logging.basicConfig()
@ -76,6 +77,29 @@ class PublicKey:
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
class Cert:
def __init__(self, data: bytes):
self.__cert = load_pem_x509_certificate(data)
@staticmethod
def from_file(filename: str) -> "Cert":
log = logging.getLogger(__name__)
log.debug(f'Importing Certificate from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
return Cert(data=data.strip())
def raw(self) -> Certificate:
return self.__cert
def pem(self) -> bytes:
return self.__cert.public_bytes(encoding=serialization.Encoding.PEM)
def load_file(filename: str) -> bytes:
log = logging.getLogger(f'{__name__}')
log.debug(f'Loading contents of file "{filename}')

View File

@ -1,3 +1,4 @@
import json
import sys
from base64 import b64encode as b64enc
from calendar import timegm
@ -7,7 +8,7 @@ from os.path import dirname, join
from uuid import uuid4, UUID
from dateutil.relativedelta import relativedelta
from jose import jwt, jwk
from jose import jwt, jwk, jws
from jose.constants import ALGORITHMS
from starlette.testclient import TestClient
@ -20,6 +21,7 @@ from util import PrivateKey, PublicKey
client = TestClient(main.app)
INSTANCE_REF = '10000000-0000-0000-0000-000000000001'
ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-000000000001', 'HelloWorld'
# INSTANCE_KEY_RSA = generate_key()
@ -69,6 +71,31 @@ def test_client_token():
assert response.status_code == 200
def test_config_token(): # todo: /leasing/v1/config-token
# https://git.collinwebdesigns.de/nvidia/nls/-/blob/main/src/test/test_config_token.py
response = client.post('/leasing/v1/config-token', json={"service_instance_ref": INSTANCE_REF})
assert response.status_code == 200
nv_response_certificate_configuration = response.json().get('certificateConfiguration')
nv_response_public_cert = nv_response_certificate_configuration.get('publicCert').encode('utf-8')
nv_jwt_decode_key = jwk.construct(nv_response_public_cert, algorithm=ALGORITHMS.RS256)
nv_response_config_token = response.json().get('configToken')
payload = jws.verify(nv_response_config_token, key=nv_jwt_decode_key, algorithms=ALGORITHMS.RS256)
payload = json.loads(payload)
assert payload.get('iss') == 'NLS Service Instance'
assert payload.get('aud') == 'NLS Licensed Client'
assert payload.get('service_instance_ref') == INSTANCE_REF
nv_si_public_key_configuration = payload.get('service_instance_public_key_configuration')
nv_si_public_key_me = nv_si_public_key_configuration.get('service_instance_public_key_me')
# assert nv_si_public_key_me.get('mod') == 1 #nv_si_public_key_mod
assert len(nv_si_public_key_me.get('mod')) == 512
assert nv_si_public_key_me.get('exp') == 65537 # nv_si_public_key_exp
def test_origins():
pass