Compare commits

...

2 Commits

Author SHA1 Message Date
Oscar Krause
e6a2de40c9 fixed test deprecations 2025-03-18 10:33:52 +01:00
Oscar Krause
fd46eecfb3 created PrivateKey / PublicKey wrapper classes 2025-03-18 09:43:44 +01:00
3 changed files with 82 additions and 63 deletions

View File

@ -21,7 +21,7 @@ from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse from starlette.responses import StreamingResponse, JSONResponse as JSONr, HTMLResponse as HTMLr, Response, RedirectResponse
from orm import Origin, Lease, init as db_init, migrate from orm import Origin, Lease, init as db_init, migrate
from util import load_private_key, load_public_key, get_pem, load_file from util import PrivateKey, PublicKey, load_file
# Load variables # Load variables
load_dotenv('../version.env') load_dotenv('../version.env')
@ -42,8 +42,8 @@ DLS_PORT = int(env('DLS_PORT', '443'))
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000')) SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001')) INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001')) ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
INSTANCE_KEY_RSA = load_private_key(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem')))) INSTANCE_KEY_RSA = PrivateKey(str(env('INSTANCE_KEY_RSA', join(dirname(__file__), 'cert/instance.private.pem'))))
INSTANCE_KEY_PUB = load_public_key(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem')))) INSTANCE_KEY_PUB = PublicKey(str(env('INSTANCE_KEY_PUB', join(dirname(__file__), 'cert/instance.public.pem'))))
TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0))) TOKEN_EXPIRE_DELTA = relativedelta(days=int(env('TOKEN_EXPIRE_DAYS', 1)), hours=int(env('TOKEN_EXPIRE_HOURS', 0)))
LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0))) LEASE_EXPIRE_DELTA = relativedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=int(env('LEASE_EXPIRE_HOURS', 0)))
LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15)) LEASE_RENEWAL_PERIOD = float(env('LEASE_RENEWAL_PERIOD', 0.15))
@ -51,8 +51,8 @@ LEASE_RENEWAL_DELTA = timedelta(days=int(env('LEASE_EXPIRE_DAYS', 90)), hours=in
CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12) CLIENT_TOKEN_EXPIRE_DELTA = relativedelta(years=12)
CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}'] CORS_ORIGINS = str(env('CORS_ORIGINS', '')).split(',') if (env('CORS_ORIGINS')) else [f'https://{DLS_URL}']
jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256) jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256) jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
# Logging # Logging
LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO LOG_LEVEL = logging.DEBUG if DEBUG else logging.INFO
@ -264,10 +264,10 @@ async def _client_token():
}, },
"service_instance_public_key_configuration": { "service_instance_public_key_configuration": {
"service_instance_public_key_me": { "service_instance_public_key_me": {
"mod": hex(INSTANCE_KEY_PUB.public_numbers().n)[2:], "mod": hex(INSTANCE_KEY_PUB.raw().public_numbers().n)[2:],
"exp": int(INSTANCE_KEY_PUB.public_numbers().e), "exp": int(INSTANCE_KEY_PUB.raw().public_numbers().e),
}, },
"service_instance_public_key_pem": get_pem(INSTANCE_KEY_PUB).decode('utf-8'), "service_instance_public_key_pem": INSTANCE_KEY_PUB.pem().decode('utf-8'),
"key_retention_mode": "LATEST_ONLY" "key_retention_mode": "LATEST_ONLY"
}, },
} }

View File

@ -1,8 +1,60 @@
import logging import logging
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey, generate_private_key
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
logging.basicConfig() logging.basicConfig()
class PrivateKey:
def __init__(self, filename: str):
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
self.key = load_pem_private_key(data.strip(), password=None)
def raw(self) -> RSAPrivateKey:
return self.key
def pem(self) -> bytes:
return self.key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
@staticmethod
def generate(public_exponent: int = 65537, key_size: int = 2048) -> RSAPrivateKey:
log = logging.getLogger(__name__)
log.debug(f'Generating RSA-Key')
return generate_private_key(public_exponent=public_exponent, key_size=key_size)
class PublicKey:
def __init__(self, filename: str):
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
self.key = load_pem_public_key(data.strip())
def raw(self) -> RSAPublicKey:
return self.key
def pem(self) -> bytes:
return self.key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def load_file(filename: str) -> bytes: def load_file(filename: str) -> bytes:
log = logging.getLogger(f'{__name__}') log = logging.getLogger(f'{__name__}')
log.debug(f'Loading contents of file "{filename}') log.debug(f'Loading contents of file "{filename}')
@ -11,53 +63,6 @@ def load_file(filename: str) -> bytes:
return content return content
def load_private_key(filename: str) -> "RSAPrivateKey":
from cryptography.hazmat.primitives.serialization import load_pem_private_key
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
return load_pem_private_key(data.strip(), password=None)
def load_public_key(filename: str) -> "RSAPublicKey":
from cryptography.hazmat.primitives.serialization import load_pem_public_key
log = logging.getLogger(__name__)
log.debug(f'Importing RSA-Key from "{filename}"')
with open(filename, 'rb') as f:
data = f.read()
return load_pem_public_key(data.strip())
def get_pem(key) -> bytes | None:
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives import serialization
if isinstance(key, RSAPrivateKey):
return key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
)
if isinstance(key, RSAPublicKey):
return key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
def generate_private_key() -> "RSAPrivateKey":
from cryptography.hazmat.primitives.asymmetric import rsa
log = logging.getLogger(__name__)
log.debug(f'Generating RSA-Key')
return rsa.generate_private_key(public_exponent=65537, key_size=2048)
class NV: class NV:
__DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json' __DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json'
__DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions" __DRIVER_MATRIX: None | dict = None # https://docs.nvidia.com/grid/ => "Driver Versions"

View File

@ -16,7 +16,7 @@ sys.path.append('../')
sys.path.append('../app') sys.path.append('../app')
from app import main from app import main
from util import load_private_key, load_public_key, get_pem from util import PrivateKey, PublicKey
client = TestClient(main.app) client = TestClient(main.app)
@ -25,11 +25,11 @@ ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-00000
# INSTANCE_KEY_RSA = generate_key() # INSTANCE_KEY_RSA = generate_key()
# INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key() # INSTANCE_KEY_PUB = INSTANCE_KEY_RSA.public_key()
INSTANCE_KEY_RSA = load_private_key(str(join(dirname(__file__), '../app/cert/instance.private.pem'))) INSTANCE_KEY_RSA = PrivateKey(str(join(dirname(__file__), '../app/cert/instance.private.pem')))
INSTANCE_KEY_PUB = load_public_key(str(join(dirname(__file__), '../app/cert/instance.public.pem'))) INSTANCE_KEY_PUB = PublicKey(str(join(dirname(__file__), '../app/cert/instance.public.pem')))
jwt_encode_key = jwk.construct(get_pem(INSTANCE_KEY_RSA), algorithm=ALGORITHMS.RS256) jwt_encode_key = jwk.construct(INSTANCE_KEY_RSA.pem(), algorithm=ALGORITHMS.RS256)
jwt_decode_key = jwk.construct(get_pem(INSTANCE_KEY_PUB), algorithm=ALGORITHMS.RS256) jwt_decode_key = jwk.construct(INSTANCE_KEY_PUB.pem(), algorithm=ALGORITHMS.RS256)
def __bearer_token(origin_ref: str) -> str: def __bearer_token(origin_ref: str) -> str:
@ -187,8 +187,6 @@ def test_leasing_v1_lessor():
assert len(lease_result_list[0]['lease']['ref']) == 36 assert len(lease_result_list[0]['lease']['ref']) == 36
assert str(UUID(lease_result_list[0]['lease']['ref'])) == lease_result_list[0]['lease']['ref'] assert str(UUID(lease_result_list[0]['lease']['ref'])) == lease_result_list[0]['lease']['ref']
return lease_result_list[0]['lease']['ref']
def test_leasing_v1_lessor_lease(): def test_leasing_v1_lessor_lease():
response = client.get('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)}) response = client.get('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)})
@ -231,7 +229,23 @@ def test_leasing_v1_lease_delete():
def test_leasing_v1_lessor_lease_remove(): def test_leasing_v1_lessor_lease_remove():
lease_ref = test_leasing_v1_lessor() # see "test_leasing_v1_lessor()"
payload = {
'fulfillment_context': {
'fulfillment_class_ref_list': []
},
'lease_proposal_list': [{
'license_type_qualifiers': {'count': 1},
'product': {'name': 'NVIDIA RTX Virtual Workstation'}
}],
'proposal_evaluation_mode': 'ALL_OF',
'scope_ref_list': [ALLOTMENT_REF]
}
response = client.post('/leasing/v1/lessor', json=payload, headers={'authorization': __bearer_token(ORIGIN_REF)})
lease_result_list = response.json().get('lease_result_list')
lease_ref = lease_result_list[0]['lease']['ref']
#
response = client.delete('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)}) response = client.delete('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)})
assert response.status_code == 200 assert response.status_code == 200