"""Tests for ``app.main``, exercised in-process through starlette's ``TestClient``.

Key material for the root CA, intermediate CA and service instance (SI) is
loaded once at import time via ``util.CASetup`` and shared by all tests.
Several tests depend on state created by earlier ones (e.g. the lease created
by ``test_leasing_v1_lessor`` is read by the later lease tests), so the
tests rely on running in file order.
"""
import ast
import json
import sys
from base64 import b64encode as b64enc
from calendar import timegm
from datetime import datetime, UTC
from hashlib import sha256
from json import loads as json_loads, dumps as json_dumps
from uuid import uuid4, UUID

from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
from cryptography.hazmat.primitives.hashes import SHA256
from dateutil.relativedelta import relativedelta
import jwt
from starlette.testclient import TestClient

# add relative path to use packages as they were in the app/ dir
# (relative sys.path entries resolve against the current working directory)
sys.path.append('../')
sys.path.append('../app')

from app import main
from util import CASetup, PrivateKey, PublicKey, Cert

client = TestClient(main.app)

# Instance
INSTANCE_REF = '10000000-0000-0000-0000-000000000001'
ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-000000000001', 'HelloWorld'

# CA & Signing
ca_setup = CASetup(service_instance_ref=INSTANCE_REF)
my_root_private_key = PrivateKey.from_file(ca_setup.root_private_key_filename)
my_root_certificate = Cert.from_file(ca_setup.root_certificate_filename)
my_ca_certificate = Cert.from_file(ca_setup.ca_certificate_filename)
my_ca_private_key = PrivateKey.from_file(ca_setup.ca_private_key_filename)
my_si_private_key = PrivateKey.from_file(ca_setup.si_private_key_filename)
my_si_private_key_as_pem = my_si_private_key.pem()
my_si_public_key = my_si_private_key.public_key()
my_si_public_key_as_pem = my_si_private_key.public_key().pem()
my_si_certificate = Cert.from_file(ca_setup.si_certificate_filename)

jwt_encode_key = my_si_private_key.pem()
jwt_decode_key = my_si_private_key.public_key().pem()


def __bearer_token(origin_ref: str) -> str:
    """Return an ``authorization`` header value: ``Bearer <RS256 JWT>`` for *origin_ref*.

    The JWT is signed with the service-instance private key (``jwt_encode_key``).
    """
    token = jwt.encode(payload={"origin_ref": origin_ref}, key=jwt_encode_key, algorithm='RS256')
    return f'Bearer {token}'


def __signature_header_hex(header_value: str | None) -> bytes:
    """Parse an ``X-NLS-Signature`` header value into hex-encoded signature bytes.

    The header carries the signature's hex digits written as a Python bytes
    literal, e.g. ``b'3fa1...'`` (see ``test_signing``). ``ast.literal_eval``
    only accepts literals, so unlike ``eval`` it cannot execute code placed in
    the header.

    :param header_value: raw header value; ``None`` if the header is absent.
    :returns: the ASCII hex digits as ``bytes`` (decode with ``bytes.fromhex``).
    :raises AssertionError: if the header is missing or not a bytes literal.
    """
    assert header_value is not None, 'response has no X-NLS-Signature header'
    signature_hex = ast.literal_eval(header_value)
    assert isinstance(signature_hex, bytes), f'unexpected X-NLS-Signature value: {header_value!r}'
    return signature_hex


def test_signing():
    """Sign/verify with the SI key pair, both directly and via the header encoding."""
    signature_set_header = my_si_private_key.generate_signature(b'Hello')

    # test plain
    my_si_public_key.verify_signature(signature_set_header, b'Hello')

    # test "X-NLS-Signature: b'....'"
    x_nls_signature_header_value = f'{signature_set_header.hex().encode()}'
    assert x_nls_signature_header_value.startswith('b\'')
    assert x_nls_signature_header_value.endswith('\'')

    # test parsing the header value back into the raw signature
    signature_get_header = __signature_header_hex(x_nls_signature_header_value)
    signature_get_header = bytes.fromhex(signature_get_header.decode('ascii'))
    my_si_public_key.verify_signature(signature_get_header, b'Hello')


def test_keypair_and_certificates():
    """Check certificate/key pairing, modulus lengths and the root -> CA -> SI signature chain."""
    assert my_root_certificate.public_key().mod() == my_root_private_key.public_key().mod()
    assert my_ca_certificate.public_key().mod() == my_ca_private_key.public_key().mod()
    assert my_si_certificate.public_key().mod() == my_si_public_key.mod()

    assert len(my_root_certificate.public_key().mod()) == 1024
    assert len(my_ca_certificate.public_key().mod()) == 1024
    assert len(my_si_certificate.public_key().mod()) == 512

    # the root key signed the CA certificate ...
    my_root_certificate.public_key().raw().verify(
        my_ca_certificate.raw().signature,
        my_ca_certificate.raw().tbs_certificate_bytes,
        PKCS1v15(),
        SHA256(),
    )

    # ... and the CA key signed the SI certificate
    my_ca_certificate.public_key().raw().verify(
        my_si_certificate.raw().signature,
        my_si_certificate.raw().tbs_certificate_bytes,
        PKCS1v15(),
        SHA256(),
    )


def test_index():
    response = client.get('/')
    assert response.status_code == 200


def test_health():
    response = client.get('/-/health')
    assert response.status_code == 200
    assert response.json().get('status') == 'up'


def test_config():
    response = client.get('/-/config')
    assert response.status_code == 200


def test_config_root_ca():
    """The root-certificate endpoint must serve the same PEM as the locally loaded root cert."""
    response = client.get('/-/config/root-certificate')
    assert response.status_code == 200
    assert response.content.decode('utf-8').strip() == my_root_certificate.pem().decode('utf-8').strip()


def test_readme():
    response = client.get('/-/readme')
    assert response.status_code == 200


def test_manage():
    response = client.get('/-/manage')
    assert response.status_code == 200


def test_client_token():
    response = client.get('/-/client-token')
    assert response.status_code == 200


def test_config_token():
    """Validate the certificate configuration and claims returned by the config-token endpoint."""
    # https://git.collinwebdesigns.de/nvidia/nls/-/blob/main/src/test/test_config_token.py

    response = client.post('/leasing/v1/config-token', json={"service_instance_ref": INSTANCE_REF})
    assert response.status_code == 200

    nv_response_certificate_configuration = response.json().get('certificateConfiguration')

    nv_ca_chain = nv_response_certificate_configuration.get('caChain')[0].encode('utf-8')
    nv_ca_chain = Cert(nv_ca_chain)

    nv_response_public_cert = nv_response_certificate_configuration.get('publicCert').encode('utf-8')
    nv_response_public_key = nv_response_certificate_configuration.get('publicKey')

    nv_si_certificate = Cert(nv_response_public_cert)
    assert nv_si_certificate.public_key().mod() == nv_response_public_key.get('mod')[0]
    assert nv_si_certificate.authority_key_identifier() == nv_ca_chain.subject_key_identifier()

    nv_response_config_token = response.json().get('configToken')

    # NOTE(review): signature verification is disabled here, so only the claims are checked.
    payload = jwt.decode(jwt=nv_response_config_token, key=nv_si_certificate.public_key().pem(), algorithms=['RS256'], options={'verify_signature': False})
    assert payload.get('iss') == 'NLS Service Instance'
    assert payload.get('aud') == 'NLS Licensed Client'
    assert payload.get('service_instance_ref') == INSTANCE_REF

    nv_si_public_key_configuration = payload.get('service_instance_public_key_configuration')
    nv_si_public_key_me = nv_si_public_key_configuration.get('service_instance_public_key_me')
    assert len(nv_si_public_key_me.get('mod')) == 512  # nv_si_public_key_mod
    assert nv_si_public_key_me.get('exp') == 65537  # nv_si_public_key_exp


def test_origins():
    pass


def test_origins_delete():
    pass


def test_leases():
    pass


def test_lease_delete():
    pass


def test_auth_v1_origin():
    """Register ORIGIN_REF; later tests authenticate as this origin."""
    payload = {
        "registration_pending": False,
        "environment": {
            "guest_driver_version": "guest_driver_version",
            "hostname": "myhost",
            "ip_address_list": ["192.168.1.123"],
            "os_version": "os_version",
            "os_platform": "os_platform",
            "fingerprint": {"mac_address_list": ["ff:ff:ff:ff:ff:ff"]},
            "host_driver_version": "host_driver_version"
        },
        "update_pending": False,
        "candidate_origin_ref": ORIGIN_REF,
    }

    response = client.post('/auth/v1/origin', json=payload)
    assert response.status_code == 200
    assert response.json().get('origin_ref') == ORIGIN_REF


# NOTE(review): this function lacks the ``test_`` prefix, so pytest never collects
# or runs it. Rename it to ``test_auth_v1_origin_update`` once the
# /auth/v1/origin/update endpoint is expected to pass.
def auth_v1_origin_update():
    payload = {
        "registration_pending": False,
        "environment": {
            "guest_driver_version": "guest_driver_version",
            "hostname": "myhost",
            "ip_address_list": ["192.168.1.123"],
            "os_version": "os_version",
            "os_platform": "os_platform",
            "fingerprint": {"mac_address_list": ["ff:ff:ff:ff:ff:ff"]},
            "host_driver_version": "host_driver_version"
        },
        "update_pending": False,
        "candidate_origin_ref": ORIGIN_REF,
    }

    response = client.post('/auth/v1/origin/update', json=payload)
    assert response.status_code == 200
    assert response.json().get('origin_ref') == ORIGIN_REF


def test_auth_v1_code():
    """Request an auth code with an unpadded base64 SHA-256 challenge of SECRET."""
    payload = {
        "code_challenge": b64enc(sha256(SECRET.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'),
        "origin_ref": ORIGIN_REF,
    }

    response = client.post('/auth/v1/code', json=payload)
    assert response.status_code == 200

    payload = jwt.decode(response.json().get('auth_code'), key=my_si_public_key_as_pem, algorithms=['RS256'])
    assert payload.get('origin_ref') == ORIGIN_REF


def test_auth_v1_token():
    """Exchange a locally signed auth code plus its verifier (SECRET) for an auth token."""
    cur_time = datetime.now(UTC)
    access_expires_on = cur_time + relativedelta(hours=1)

    payload = {
        "iat": timegm(cur_time.timetuple()),
        "exp": timegm(access_expires_on.timetuple()),
        "challenge": b64enc(sha256(SECRET.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8'),
        "origin_ref": ORIGIN_REF,
        "key_ref": "00000000-0000-0000-0000-000000000000",
        "kid": "00000000-0000-0000-0000-000000000000"
    }
    payload = {
        "auth_code": jwt.encode(payload, key=jwt_encode_key, headers={'kid': payload.get('kid')}, algorithm='RS256'),
        "code_verifier": SECRET,
    }

    response = client.post('/auth/v1/token', json=payload)
    assert response.status_code == 200

    token = response.json().get('auth_token')
    payload = jwt.decode(token, key=jwt_decode_key, algorithms=['RS256'], options={'verify_signature': False})
    assert payload.get('origin_ref') == ORIGIN_REF


def test_leasing_v1_lessor():
    """Acquire one lease and verify the echoed challenge, response signature and lease data."""
    payload = {
        'client_challenge': 'my_unique_string',
        'fulfillment_context': {
            'fulfillment_class_ref_list': []
        },
        'lease_proposal_list': [{
            'license_type_qualifiers': {'count': 1},
            'product': {'name': 'NVIDIA Virtual Applications'}
        }],
        'proposal_evaluation_mode': 'ALL_OF',
        'scope_ref_list': [ALLOTMENT_REF]
    }

    response = client.post('/leasing/v1/lessor', json=payload, headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200

    client_challenge = response.json().get('client_challenge')
    assert client_challenge == payload.get('client_challenge')

    signature = __signature_header_hex(response.headers.get('X-NLS-Signature'))
    assert len(signature) == 512
    signature = bytes.fromhex(signature.decode('ascii'))
    assert len(signature) == 256
    my_si_public_key.verify_signature(signature, response.content)

    lease_result_list = response.json().get('lease_result_list')
    assert len(lease_result_list) == 1
    assert len(lease_result_list[0]['lease']['ref']) == 36
    assert str(UUID(lease_result_list[0]['lease']['ref'])) == lease_result_list[0]['lease']['ref']
    assert lease_result_list[0]['lease']['product_name'] == 'NVIDIA Virtual Applications'
    assert lease_result_list[0]['lease']['feature_name'] == 'GRID-Virtual-Apps'


def test_leasing_v1_lessor_lease():
    """The lease created by ``test_leasing_v1_lessor`` must be the only active one."""
    response = client.get('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200

    active_lease_list = response.json().get('active_lease_list')
    assert len(active_lease_list) == 1
    assert len(active_lease_list[0]) == 36
    assert str(UUID(active_lease_list[0])) == active_lease_list[0]


def test_leasing_v1_lease_renew():
    """Renew the first active lease and verify the signed response."""
    response = client.get('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200
    active_lease_list = response.json().get('active_lease_list')
    active_lease_ref = active_lease_list[0]

    ###
    payload = {'client_challenge': 'my_unique_string'}
    response = client.put(f'/leasing/v1/lease/{active_lease_ref}', json=payload, headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200

    client_challenge = response.json().get('client_challenge')
    assert client_challenge == payload.get('client_challenge')

    signature = __signature_header_hex(response.headers.get('X-NLS-Signature'))
    assert len(signature) == 512
    signature = bytes.fromhex(signature.decode('ascii'))
    assert len(signature) == 256
    my_si_public_key.verify_signature(signature, response.content)

    lease_ref = response.json().get('lease_ref')
    assert len(lease_ref) == 36
    assert lease_ref == active_lease_ref


def test_leasing_v1_lease_delete():
    """Delete the first active lease by its ref."""
    response = client.get('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200
    active_lease_list = response.json().get('active_lease_list')
    active_lease_ref = active_lease_list[0]

    ###
    response = client.delete(f'/leasing/v1/lease/{active_lease_ref}', headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200

    lease_ref = response.json().get('lease_ref')
    assert len(lease_ref) == 36
    assert lease_ref == active_lease_ref


def test_leasing_v1_lessor_lease_remove():
    """Create a fresh lease, then release all of the origin's leases and expect exactly that one."""
    # see "test_leasing_v1_lessor()"
    payload = {
        'fulfillment_context': {
            'fulfillment_class_ref_list': []
        },
        'lease_proposal_list': [{
            'license_type_qualifiers': {'count': 1},
            'product': {'name': 'NVIDIA Virtual Applications'}
        }],
        'proposal_evaluation_mode': 'ALL_OF',
        'scope_ref_list': [ALLOTMENT_REF]
    }

    response = client.post('/leasing/v1/lessor', json=payload, headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200
    lease_result_list = response.json().get('lease_result_list')
    lease_ref = lease_result_list[0]['lease']['ref']

    #
    response = client.delete('/leasing/v1/lessor/leases', headers={'authorization': __bearer_token(ORIGIN_REF)})
    assert response.status_code == 200

    released_lease_list = response.json().get('released_lease_list')
    assert len(released_lease_list) == 1
    assert len(released_lease_list[0]) == 36
    assert released_lease_list[0] == lease_ref