mirror of
https://git.collinwebdesigns.de/oscar.krause/fastapi-dls.git
synced 2025-08-28 15:36:59 +08:00
Compare commits
20 Commits
02602597f9
...
124b874b59
Author | SHA1 | Date | |
---|---|---|---|
|
124b874b59 | ||
|
04914740a4 | ||
|
6af9cd04c9 | ||
|
29268b1658 | ||
|
938a112b8a | ||
|
16870e9d67 | ||
|
55b7437fe7 | ||
|
e7e007a45f | ||
|
161a1430cf | ||
|
1ccb203b25 | ||
|
6c1a8d42dc | ||
|
d248496f34 | ||
|
fd1babaca5 | ||
|
cd9c655d65 | ||
|
6ed4bdfe6f | ||
|
e1ae757a50 | ||
|
b0ca5d7ab5 | ||
|
14f8b54752 | ||
|
dc783e6518 | ||
|
6b54d4794b |
@ -376,7 +376,7 @@ deploy:pacman:
|
||||
release:
|
||||
image: registry.gitlab.com/gitlab-org/release-cli:latest
|
||||
stage: .post
|
||||
needs: [ build:docker, build:apt, build:pacman ]
|
||||
needs: [ deploy:docker, deploy:apt, deploy:pacman ]
|
||||
rules:
|
||||
- if: $CI_COMMIT_TAG
|
||||
script:
|
||||
|
31
README.md
31
README.md
@ -417,19 +417,20 @@ After first success you have to replace `--issue` with `--renew`.
|
||||
|
||||
# Configuration
|
||||
|
||||
| Variable | Default | Usage |
|
||||
|--------------------------|----------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `DEBUG` | `false` | Toggles `fastapi` debug mode |
|
||||
| `DLS_URL` | `localhost` | Used in client-token to tell guest driver where dls instance is reachable |
|
||||
| `DLS_PORT` | `443` | Used in client-token to tell guest driver where dls instance is reachable |
|
||||
| `TOKEN_EXPIRE_DAYS` | `1` | Client auth-token validity (used for authenticate client against api, **not `.tok` file!**) |
|
||||
| `LEASE_EXPIRE_DAYS` | `90` | Lease time in days |
|
||||
| `LEASE_RENEWAL_PERIOD` | `0.15` | The percentage of the lease period that must elapse before a licensed client can renew a license \*1 |
|
||||
| `DATABASE` | `sqlite:///db.sqlite` | See [official SQLAlchemy docs](https://docs.sqlalchemy.org/en/14/core/engines.html) |
|
||||
| `CORS_ORIGINS` | `https://{DLS_URL}` | Sets `Access-Control-Allow-Origin` header (comma separated string) \*2 |
|
||||
| `SITE_KEY_XID` | `00000000-0000-0000-0000-000000000000` | Site identification uuid |
|
||||
| `INSTANCE_REF` | `10000000-0000-0000-0000-000000000001` | Instance identification uuid |
|
||||
| `ALLOTMENT_REF` | `20000000-0000-0000-0000-000000000001` | Allotment identification uuid | |
|
||||
| Variable | Default | Usage |
|
||||
|------------------------|----------------------------------------|------------------------------------------------------------------------------------------------------|
|
||||
| `DEBUG` | `false` | Toggles `fastapi` debug mode |
|
||||
| `DLS_URL` | `localhost` | Used in client-token to tell guest driver where dls instance is reachable |
|
||||
| `DLS_PORT` | `443` | Used in client-token to tell guest driver where dls instance is reachable |
|
||||
| `CERT_PATH` | `None` | Path to a Directory where generated Certificates are stored. Defaults to `/<app-dir>/cert`. |
|
||||
| `TOKEN_EXPIRE_DAYS` | `1` | Client auth-token validity (used for authenticate client against api, **not `.tok` file!**) |
|
||||
| `LEASE_EXPIRE_DAYS` | `90` | Lease time in days |
|
||||
| `LEASE_RENEWAL_PERIOD` | `0.15` | The percentage of the lease period that must elapse before a licensed client can renew a license \*1 |
|
||||
| `DATABASE` | `sqlite:///db.sqlite` | See [official SQLAlchemy docs](https://docs.sqlalchemy.org/en/14/core/engines.html) |
|
||||
| `CORS_ORIGINS` | `https://{DLS_URL}` | Sets `Access-Control-Allow-Origin` header (comma separated string) \*2 |
|
||||
| `SITE_KEY_XID` | `00000000-0000-0000-0000-000000000000` | Site identification uuid |
|
||||
| `INSTANCE_REF` | `10000000-0000-0000-0000-000000000001` | Instance identification uuid |
|
||||
| `ALLOTMENT_REF` | `20000000-0000-0000-0000-000000000001` | Allotment identification uuid |
|
||||
|
||||
\*1 For example, if the lease period is one day and the renewal period is 20%, the client attempts to renew its license
|
||||
every 4.8 hours. If network connectivity is lost, the loss of connectivity is detected during license renewal and the
|
||||
@ -535,9 +536,9 @@ Status endpoint, used for *healthcheck*.
|
||||
|
||||
Shows current runtime environment variables and their values.
|
||||
|
||||
**`GET /-/config/root-ca`**
|
||||
**`GET /-/config/root-certificate`**
|
||||
|
||||
Returns the Root-CA Certificate which is used. This is required for patching `nvidia-gridd` on 18.x releases.
|
||||
Returns the Root-Certificate Certificate which is used. This is required for patching `nvidia-gridd` on 18.x releases.
|
||||
|
||||
**`GET /-/readme`**
|
||||
|
||||
|
44
app/main.py
44
app/main.py
@ -7,6 +7,7 @@ from hashlib import sha256
|
||||
from json import loads as json_loads, dumps as json_dumps
|
||||
from os import getenv as env
|
||||
from os.path import join, dirname
|
||||
from textwrap import wrap
|
||||
from uuid import uuid4
|
||||
|
||||
from dateutil.relativedelta import relativedelta
|
||||
@ -39,6 +40,7 @@ db_init(db), migrate(db)
|
||||
# Load DLS variables (all prefixed with "INSTANCE_*" is used as "SERVICE_INSTANCE_*" or "SI_*" in official dls service)
|
||||
DLS_URL = str(env('DLS_URL', 'localhost'))
|
||||
DLS_PORT = int(env('DLS_PORT', '443'))
|
||||
CERT_PATH = str(env('CERT_PATH', None))
|
||||
SITE_KEY_XID = str(env('SITE_KEY_XID', '00000000-0000-0000-0000-000000000000'))
|
||||
INSTANCE_REF = str(env('INSTANCE_REF', '10000000-0000-0000-0000-000000000001'))
|
||||
ALLOTMENT_REF = str(env('ALLOTMENT_REF', '20000000-0000-0000-0000-000000000001'))
|
||||
@ -52,7 +54,9 @@ DT_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'
|
||||
PRODUCT_MAPPING = ProductMapping(filename=join(dirname(__file__), 'static/product_mapping.json'))
|
||||
|
||||
# Create certificate chain and signing keys
|
||||
ca_setup = CASetup(service_instance_ref=INSTANCE_REF)
|
||||
ca_setup = CASetup(service_instance_ref=INSTANCE_REF, cert_path=CERT_PATH)
|
||||
my_root_private_key = PrivateKey.from_file(ca_setup.root_private_key_filename)
|
||||
my_root_public_key = my_root_private_key.public_key()
|
||||
my_root_certificate = Cert.from_file(ca_setup.root_certificate_filename)
|
||||
my_ca_certificate = Cert.from_file(ca_setup.ca_certificate_filename)
|
||||
my_si_certificate = Cert.from_file(ca_setup.si_certificate_filename)
|
||||
@ -151,10 +155,9 @@ async def _config():
|
||||
return Response(content=json_dumps(response), media_type='application/json', status_code=200)
|
||||
|
||||
|
||||
|
||||
@app.get('/-/config/root-ca', summary='* Root CA', description='returns Root-CA needed for patching nvidia-gridd')
|
||||
@app.get('/-/config/root-certificate', summary='* Root Certificate', description='returns Root--Certificate needed for patching nvidia-gridd')
|
||||
async def _config():
|
||||
return Response(content=my_root_certificate.pem().decode('utf-8'), media_type='text/plain')
|
||||
return Response(content=my_root_certificate.pem().decode('ascii').strip(), media_type='text/plain')
|
||||
|
||||
|
||||
@app.get('/-/readme', summary='* Readme')
|
||||
@ -287,7 +290,7 @@ async def _client_token():
|
||||
"mod": my_si_public_key.mod(),
|
||||
"exp": my_si_public_key.exp(),
|
||||
},
|
||||
"service_instance_public_key_pem": my_si_private_key.public_key().pem().decode('utf-8'),
|
||||
"service_instance_public_key_pem": my_si_public_key.pem().decode('utf-8').strip(),
|
||||
"key_retention_mode": "LATEST_ONLY"
|
||||
},
|
||||
}
|
||||
@ -462,8 +465,7 @@ async def leasing_v1_config_token(request: Request):
|
||||
"mod": my_si_public_key.mod(),
|
||||
"exp": my_si_public_key.exp(),
|
||||
},
|
||||
# 64 chars per line (pem default)
|
||||
"service_instance_public_key_pem": my_si_private_key.public_key().pem().decode('utf-8').strip(),
|
||||
"service_instance_public_key_pem": my_si_public_key.pem().decode('utf-8').strip(),
|
||||
"key_retention_mode": "LATEST_ONLY"
|
||||
},
|
||||
}
|
||||
@ -471,18 +473,34 @@ async def leasing_v1_config_token(request: Request):
|
||||
my_jwt_encode_key = jwk.construct(my_si_private_key.pem().decode('utf-8'), algorithm=ALGORITHMS.RS256)
|
||||
config_token = jws.sign(payload, key=my_jwt_encode_key, headers=None, algorithm=ALGORITHMS.RS256)
|
||||
|
||||
response_ca_chain = my_ca_certificate.pem().decode('utf-8').strip().replace('\n', '\r\n')
|
||||
response_si_certificate = my_si_certificate.pem().decode('utf-8').strip().replace('\n', '\r\n')
|
||||
response_ca_chain = my_ca_certificate.pem().decode('utf-8').strip().replace('\n', '\r\n') # 76 chars per line on original response
|
||||
"""
|
||||
response_ca_chain = my_ca_certificate.pem().decode('utf-8').strip()
|
||||
response_ca_chain = response_ca_chain.replace('-----BEGIN CERTIFICATE-----', '')
|
||||
response_ca_chain = response_ca_chain.replace('-----END CERTIFICATE-----', '')
|
||||
response_ca_chain = response_ca_chain.replace('\n', '')
|
||||
response_ca_chain = wrap(response_ca_chain, 76)
|
||||
response_ca_chain = '\r\n'.join(response_ca_chain)
|
||||
response_ca_chain = f'-----BEGIN CERTIFICATE-----\r\n{response_ca_chain}\r\n-----END CERTIFICATE-----'
|
||||
"""
|
||||
response_si_certificate = my_si_certificate.pem().decode('utf-8').strip().replace('\n', '\r\n') # 76 chars per line on original response
|
||||
"""
|
||||
response_si_certificate = my_si_certificate.pem().decode('utf-8').strip()
|
||||
response_si_certificate = response_si_certificate.replace('-----BEGIN CERTIFICATE-----', '')
|
||||
response_si_certificate = response_si_certificate.replace('-----END CERTIFICATE-----', '')
|
||||
response_si_certificate = response_si_certificate.replace('\n', '')
|
||||
response_si_certificate = wrap(response_si_certificate, 76)
|
||||
response_si_certificate = '\r\n'.join(response_si_certificate)
|
||||
response_si_certificate = f'-----BEGIN CERTIFICATE-----\r\n{response_si_certificate}\r\n-----END CERTIFICATE-----'
|
||||
"""
|
||||
|
||||
response = {
|
||||
"certificateConfiguration": {
|
||||
# 76 chars per line
|
||||
"caChain": [response_ca_chain],
|
||||
# 76 chars per line
|
||||
"publicCert": response_si_certificate,
|
||||
"publicKey": {
|
||||
"exp": int(my_si_certificate.raw().public_key().public_numbers().e),
|
||||
"mod": [hex(my_si_certificate.raw().public_key().public_numbers().n)[2:]],
|
||||
"exp": my_si_certificate.public_key().exp(),
|
||||
"mod": [my_si_certificate.public_key().mod()],
|
||||
},
|
||||
},
|
||||
"configToken": config_token,
|
||||
|
42
app/util.py
42
app/util.py
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from datetime import datetime, UTC, timedelta
|
||||
from json import loads as json_loads
|
||||
from os.path import join, dirname, isfile
|
||||
from os.path import join, dirname, isfile, isdir
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat._oid import NameOID
|
||||
@ -38,9 +38,13 @@ class CASetup:
|
||||
SI_PRIVATE_KEY_FILENAME = 'si_private_key.pem'
|
||||
SI_CERTIFICATE_FILENAME = 'si_certificate.pem'
|
||||
|
||||
def __init__(self, service_instance_ref: str):
|
||||
def __init__(self, service_instance_ref: str, cert_path: str = None):
|
||||
cert_path_prefix = join(dirname(__file__), 'cert')
|
||||
if cert_path is not None and len(cert_path) > 0 and isdir(cert_path):
|
||||
cert_path_prefix = cert_path
|
||||
|
||||
self.service_instance_ref = service_instance_ref
|
||||
self.root_private_key_filename = join(dirname(__file__), 'cert', CASetup.ROOT_PRIVATE_KEY_FILENAME)
|
||||
self.root_private_key_filename = join(cert_path_prefix, CASetup.ROOT_PRIVATE_KEY_FILENAME)
|
||||
self.root_certificate_filename = join(dirname(__file__), 'cert', CASetup.ROOT_CERTIFICATE_FILENAME)
|
||||
self.ca_private_key_filename = join(dirname(__file__), 'cert', CASetup.CA_PRIVATE_KEY_FILENAME)
|
||||
self.ca_certificate_filename = join(dirname(__file__), 'cert', CASetup.CA_CERTIFICATE_FILENAME)
|
||||
@ -81,7 +85,20 @@ class CASetup:
|
||||
.not_valid_before(datetime.now(tz=UTC) - timedelta(days=1))
|
||||
.not_valid_after(datetime.now(tz=UTC) + timedelta(days=365 * 10))
|
||||
.add_extension(x509.BasicConstraints(ca=True, path_length=None), critical=True)
|
||||
.add_extension(x509.KeyUsage(
|
||||
digital_signature=False,
|
||||
key_encipherment=False,
|
||||
key_cert_sign=True,
|
||||
key_agreement=False,
|
||||
content_commitment=False,
|
||||
data_encipherment=False,
|
||||
crl_sign=True,
|
||||
encipher_only=False,
|
||||
decipher_only=False),
|
||||
critical=True
|
||||
)
|
||||
.add_extension(x509.SubjectKeyIdentifier.from_public_key(my_root_public_key), critical=False)
|
||||
.add_extension(x509.AuthorityKeyIdentifier.from_issuer_public_key(my_root_public_key), critical=False)
|
||||
.sign(my_root_private_key, hashes.SHA256()))
|
||||
|
||||
my_root_private_key_as_pem = my_root_private_key.private_bytes(
|
||||
@ -134,7 +151,6 @@ class CASetup:
|
||||
critical=True
|
||||
)
|
||||
.add_extension(x509.SubjectKeyIdentifier.from_public_key(my_ca_public_key), critical=False)
|
||||
# .add_extension(x509.AuthorityKeyIdentifier.from_issuer_public_key(my_root_public_key), critical=False)
|
||||
.add_extension(x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
|
||||
my_root_certificate.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value
|
||||
), critical=False)
|
||||
@ -314,16 +330,22 @@ class Cert:
|
||||
def pem(self) -> bytes:
|
||||
return self.__cert.public_bytes(encoding=serialization.Encoding.PEM)
|
||||
|
||||
def public_key(self) -> "PublicKey":
|
||||
data = self.__cert.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo
|
||||
)
|
||||
return PublicKey(data=data)
|
||||
|
||||
def signature(self) -> bytes:
|
||||
return self.__cert.signature
|
||||
|
||||
def subject_key_identifier(self):
|
||||
return self.__cert.extensions.get_extension_for_class(x509.SubjectKeyIdentifier).value.key_identifier
|
||||
|
||||
def authority_key_identifier(self):
|
||||
return self.__cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier).value.key_identifier
|
||||
|
||||
def load_file(filename: str) -> bytes:
|
||||
log = logging.getLogger(f'{__name__}')
|
||||
log.debug(f'Loading contents of file "{filename}')
|
||||
with open(filename, 'rb') as file:
|
||||
content = file.read()
|
||||
return content
|
||||
|
||||
class DriverMatrix:
|
||||
__DRIVER_MATRIX_FILENAME = 'static/driver_matrix.json'
|
||||
|
49
test/main.py
49
test/main.py
@ -6,6 +6,8 @@ from datetime import datetime, UTC
|
||||
from hashlib import sha256
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15
|
||||
from cryptography.hazmat.primitives.hashes import SHA256
|
||||
from dateutil.relativedelta import relativedelta
|
||||
from jose import jwt, jwk, jws
|
||||
from jose.constants import ALGORITHMS
|
||||
@ -26,11 +28,15 @@ ORIGIN_REF, ALLOTMENT_REF, SECRET = str(uuid4()), '20000000-0000-0000-0000-00000
|
||||
|
||||
# CA & Signing
|
||||
ca_setup = CASetup(service_instance_ref=INSTANCE_REF)
|
||||
my_root_private_key = PrivateKey.from_file(ca_setup.root_private_key_filename)
|
||||
my_root_certificate = Cert.from_file(ca_setup.root_certificate_filename)
|
||||
my_ca_certificate = Cert.from_file(ca_setup.ca_certificate_filename)
|
||||
my_ca_private_key = PrivateKey.from_file(ca_setup.ca_private_key_filename)
|
||||
my_si_private_key = PrivateKey.from_file(ca_setup.si_private_key_filename)
|
||||
my_si_private_key_as_pem = my_si_private_key.pem()
|
||||
my_si_public_key = my_si_private_key.public_key()
|
||||
my_si_public_key_as_pem = my_si_private_key.public_key().pem()
|
||||
my_si_certificate = Cert.from_file(ca_setup.si_certificate_filename)
|
||||
|
||||
jwt_encode_key = jwk.construct(my_si_private_key_as_pem, algorithm=ALGORITHMS.RS256)
|
||||
jwt_decode_key = jwk.construct(my_si_public_key_as_pem, algorithm=ALGORITHMS.RS256)
|
||||
@ -59,6 +65,31 @@ def test_signing():
|
||||
my_si_public_key.verify_signature(signature_get_header, b'Hello')
|
||||
|
||||
|
||||
def test_keypair_and_certificates():
|
||||
assert my_root_certificate.public_key().mod() == my_root_private_key.public_key().mod()
|
||||
assert my_ca_certificate.public_key().mod() == my_ca_private_key.public_key().mod()
|
||||
assert my_si_certificate.public_key().mod() == my_si_public_key.mod()
|
||||
|
||||
assert len(my_root_certificate.public_key().mod()) == 1024
|
||||
assert len(my_ca_certificate.public_key().mod()) == 1024
|
||||
assert len(my_si_certificate.public_key().mod()) == 512
|
||||
|
||||
#assert my_si_certificate.public_key().mod() != my_si_public_key.mod()
|
||||
|
||||
my_root_certificate.public_key().raw().verify(
|
||||
my_ca_certificate.raw().signature,
|
||||
my_ca_certificate.raw().tbs_certificate_bytes,
|
||||
PKCS1v15(),
|
||||
SHA256(),
|
||||
)
|
||||
my_ca_certificate.public_key().raw().verify(
|
||||
my_si_certificate.raw().signature,
|
||||
my_si_certificate.raw().tbs_certificate_bytes,
|
||||
PKCS1v15(),
|
||||
SHA256(),
|
||||
)
|
||||
|
||||
|
||||
def test_index():
|
||||
response = client.get('/')
|
||||
assert response.status_code == 200
|
||||
@ -76,9 +107,9 @@ def test_config():
|
||||
|
||||
|
||||
def test_config_root_ca():
|
||||
response = client.get('/-/config/root-ca')
|
||||
response = client.get('/-/config/root-certificate')
|
||||
assert response.status_code == 200
|
||||
assert response.content.decode('utf-8') == my_root_certificate.pem().decode('utf-8')
|
||||
assert response.content.decode('utf-8').strip() == my_root_certificate.pem().decode('utf-8').strip()
|
||||
|
||||
|
||||
def test_readme():
|
||||
@ -103,7 +134,17 @@ def test_config_token():
|
||||
assert response.status_code == 200
|
||||
|
||||
nv_response_certificate_configuration = response.json().get('certificateConfiguration')
|
||||
|
||||
nv_ca_chain = nv_response_certificate_configuration.get('caChain')[0].encode('utf-8')
|
||||
nv_ca_chain = Cert(nv_ca_chain)
|
||||
|
||||
nv_response_public_cert = nv_response_certificate_configuration.get('publicCert').encode('utf-8')
|
||||
nv_response_public_key = nv_response_certificate_configuration.get('publicKey')
|
||||
|
||||
nv_si_certificate = Cert(nv_response_public_cert)
|
||||
assert nv_si_certificate.public_key().mod() == nv_response_public_key.get('mod')[0]
|
||||
assert nv_si_certificate.authority_key_identifier() == nv_ca_chain.subject_key_identifier()
|
||||
|
||||
nv_jwt_decode_key = jwk.construct(nv_response_public_cert, algorithm=ALGORITHMS.RS256)
|
||||
|
||||
nv_response_config_token = response.json().get('configToken')
|
||||
@ -116,8 +157,8 @@ def test_config_token():
|
||||
|
||||
nv_si_public_key_configuration = payload.get('service_instance_public_key_configuration')
|
||||
nv_si_public_key_me = nv_si_public_key_configuration.get('service_instance_public_key_me')
|
||||
# assert nv_si_public_key_me.get('mod') == 1 #nv_si_public_key_mod
|
||||
assert len(nv_si_public_key_me.get('mod')) == 512
|
||||
|
||||
assert len(nv_si_public_key_me.get('mod')) == 512 # nv_si_public_key_mod
|
||||
assert nv_si_public_key_me.get('exp') == 65537 # nv_si_public_key_exp
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user