Skip to content

Commit

Permalink
Merge pull request #8659 from OpenMined/snwagh/attestation-service
Browse files Browse the repository at this point in the history
Attestation service in Syft
  • Loading branch information
snwagh authored May 6, 2024
2 parents b6fcea9 + 7c8c19a commit de906b9
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 15 deletions.
104 changes: 104 additions & 0 deletions packages/grid/enclave/attestation/enclave-development.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,107 @@ print ("[RemoteGPUTest] node name :", client.get_name())
client.add_verifier(attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, "")
client.attest()
```

### Instructions for using helm charts

- The attestation container runs inside the backend pod (so backend pod has two containers now). However, in order to run the attestation container, you need to uncomment the attestation flags in `packages/grid/helm/values.dev.yaml`
- Next, we run the deployment. Since k3d creates an intermediate layer of nesting, we need to mount some volumes from host to k3d registry. Thus, when launching, use the following tox command `tox -e dev.k8s.start -- --volume /sys/kernel/security:/sys/kernel/security --volume /dev/tmprm0:/dev/tmprm0`
- Finally, note that the GPU privileges/drivers etc. have not been completed so while the GPU attestation endpoints should work, they will not produce the expected tokens. To test the GPU code, follow the steps provided in [For GPU Attestation
](#for-gpu-attestation) to look at the tokens.

### Local Client-side Verification

Use the following function to perform local, client-side verification of tokens. They expire quick.

```python3
def verify_token(token: str, token_type: str):
"""
Verifies a JSON Web Token (JWT) using a public key obtained from a JWKS (JSON Web Key Set) endpoint,
based on the specified type of token ('cpu' or 'gpu'). The function handles two distinct processes
for token verification depending on the type specified:
- 'cpu': Fetches the JWKS from the 'jku' URL specified in the JWT's unverified header,
finds the key by 'kid', and converts the JWK to a PEM format public key for verification.
- 'gpu': Directly uses a fixed JWKS URL to retrieve the keys, finds the key by 'kid', and uses the
'x5c' field to extract a certificate which is then used to verify the token.
Parameters:
token (str): The JWT that needs to be verified.
type (str): Type of the token which dictates the verification process; expected values are 'cpu' or 'gpu'.
Returns:
bool: True if the JWT is successfully verified, False otherwise.
Raises:
Exception: Raises various exceptions internally but catches them to return False, except for
printing error messages related to the specific failures (e.g., key not found, invalid certificate).
Example usage:
verify_token('your.jwt.token', 'cpu')
verify_token('your.jwt.token', 'gpu')
Note:
- The function prints out details about the verification process and errors, if any.
- Ensure that the cryptography and PyJWT libraries are properly installed and updated in your environment.
"""
import jwt
import json
import base64
import requests
from jwt.algorithms import RSAAlgorithm
from cryptography.x509 import load_der_x509_certificate
from cryptography.hazmat.primitives import serialization


# Determine JWKS URL based on the token type
if token_type.lower() == "gpu":
jwks_url = 'https://nras.attestation.nvidia.com/.well-known/jwks.json'
else:
unverified_header = jwt.get_unverified_header(token)
jwks_url = unverified_header['jku']

# Fetch the JWKS from the endpoint
jwks = requests.get(jwks_url).json()

# Get the key ID from the JWT header
header = jwt.get_unverified_header(token)
kid = header['kid']

# Find the key with the matching kid in the JWKS
key = next((item for item in jwks["keys"] if item["kid"] == kid), None)
if not key:
print("Public key not found in JWKS list.")
return False

# Convert the key based on the token type
if token_type.lower() == "gpu" and "x5c" in key:
try:
cert_bytes = base64.b64decode(key['x5c'][0])
cert = load_der_x509_certificate(cert_bytes)
public_key = cert.public_key()
except Exception as e:
print("Failed to process certificate:", str(e))
return False
elif token_type.lower() == "cpu":
try:
public_key = RSAAlgorithm.from_jwk(key)
except Exception as e:
print("Failed to convert JWK to PEM:", str(e))
return False
else:
print("Invalid token_type or key information.")
return False

# Verify the JWT using the public key
try:
payload = jwt.decode(token, public_key, algorithms=[header['alg']], options={"verify_exp": True})
print("JWT Payload:", json.dumps(payload, indent=2))
return True
except jwt.ExpiredSignatureError:
print("JWT token has expired.")
except jwt.InvalidTokenError as e:
print("JWT token signature is invalid:", str(e))

return False
```
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NRAS_URL = "https://nras.attestation.nvidia.com/v1/attest/gpu"
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from loguru import logger

# relative
from .attestation_models import CPUAttestationResponseModel
from .attestation_models import GPUAttestationResponseModel
from .attestation_models import ResponseModel
from .cpu_attestation import attest_cpu
from .gpu_attestation import attest_gpu
from .models import CPUAttestationResponseModel
from .models import GPUAttestationResponseModel
from .models import ResponseModel

# Logging Configuration
log_level = os.getenv("APP_LOG_LEVEL", "INFO").upper()
Expand All @@ -28,11 +28,11 @@ async def read_root() -> ResponseModel:

@app.get("/attest/cpu", response_model=CPUAttestationResponseModel)
async def attest_cpu_endpoint() -> CPUAttestationResponseModel:
cpu_attest_res = attest_cpu()
return CPUAttestationResponseModel(result=cpu_attest_res)
cpu_attest_res, cpu_attest_token = attest_cpu()
return CPUAttestationResponseModel(result=cpu_attest_res, token=cpu_attest_token)


@app.get("/attest/gpu", response_model=GPUAttestationResponseModel)
async def attest_gpu_endpoint() -> GPUAttestationResponseModel:
gpu_attest_res = attest_gpu()
return GPUAttestationResponseModel(result=gpu_attest_res)
gpu_attest_res, gpu_attest_token = attest_gpu()
return GPUAttestationResponseModel(result=gpu_attest_res, token=gpu_attest_token)
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ class ResponseModel(BaseModel):

class CPUAttestationResponseModel(BaseModel):
result: str
token: str = ""
vendor: str | None = None # Hardware Manufacturer


class GPUAttestationResponseModel(BaseModel):
result: str
token: str = ""
vendor: str | None = None # Hardware Manufacturer
15 changes: 12 additions & 3 deletions packages/grid/enclave/attestation/server/cpu_attestation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from loguru import logger


def attest_cpu() -> str:
def attest_cpu() -> tuple[str, str]:
# Fetch report from Micrsoft Attestation library
cpu_report = subprocess.run(
["/app/AttestationClient"], capture_output=True, text=True
Expand All @@ -14,7 +14,16 @@ def attest_cpu() -> str:
logger.debug(f"Stderr: {cpu_report.stderr}")

logger.info("Attestation Return Code: {}", cpu_report.returncode)
res = "False"
if cpu_report.returncode == 0 and cpu_report.stdout == "true":
return "True"
res = "True"

return "False"
# Fetch token from Micrsoft Attestation library
cpu_token = subprocess.run(
["/app/AttestationClient", "-o", "token"], capture_output=True, text=True
)
logger.debug(f"Stdout: {cpu_token.stdout}")
logger.debug(f"Stderr: {cpu_token.stderr}")

logger.info("Attestation Token Return Code: {}", cpu_token.returncode)
return res, cpu_token.stdout
38 changes: 34 additions & 4 deletions packages/grid/enclave/attestation/server/gpu_attestation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
# stdlib
import io
import re
import sys

# third party
from loguru import logger
from nv_attestation_sdk import attestation

NRAS_URL = "https://nras.attestation.nvidia.com/v1/attest/gpu"
# relative
from .attestation_constants import NRAS_URL


# Function to process captured output to extract the token
def extract_token(captured_value: str) -> str:
match = re.search(r"Entity Attestation Token is (\S+)", captured_value)
if match:
token = match.group(1) # Extract the token, which is in group 1 of the match
return token
else:
return "Token not found"


def attest_gpu() -> str:
def attest_gpu() -> tuple[str, str]:
# Fetch report from Nvidia Attestation SDK
client = attestation.Attestation("Attestation Node")

Expand All @@ -15,7 +31,21 @@ def attest_gpu() -> str:
client.add_verifier(
attestation.Devices.GPU, attestation.Environment.REMOTE, NRAS_URL, ""
)

# Step 1: Redirect stdout
original_stdout = sys.stdout # Save a reference to the original standard output
captured_output = io.StringIO() # Create a StringIO object to capture output
sys.stdout = captured_output # Redirect stdout to the StringIO object

# Step 2: Call the function
gpu_report = client.attest()
logger.info("[RemoteGPUTest] report : {}, {}", gpu_report, type(gpu_report))

return str(gpu_report)
# Step 3: Get the content of captured output and reset stdout
captured_value = captured_output.getvalue()
sys.stdout = original_stdout # Reset stdout to its original state

# Step 4: Extract the token from the captured output
token = extract_token(captured_value)

logger.info("[RemoteGPUTest] report : {}, {}", gpu_report, type(gpu_report))
return str(gpu_report), token
2 changes: 1 addition & 1 deletion packages/grid/enclave/attestation/start.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -e
export PATH="/root/.local/bin:${PATH}"

APP_MODULE=server.main:app
APP_MODULE=server.attestation_main:app
APP_LOG_LEVEL=${APP_LOG_LEVEL:-info}
UVICORN_LOG_LEVEL=${UVICORN_LOG_LEVEL:-info}
HOST=${HOST:-0.0.0.0}
Expand Down
6 changes: 6 additions & 0 deletions packages/syft/src/syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,11 @@ def _orchestra() -> Orchestra:
return Orchestra


@module_property
def hello_baby() -> None:
print("Hello baby!")
print("Welcome to the world. \u2764\ufe0f")


def search(name: str) -> SearchResults:
return Search(_domains()).search(name=name)
2 changes: 2 additions & 0 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ..service.action.action_store import MongoActionStore
from ..service.action.action_store import SQLiteActionStore
from ..service.api.api_service import APIService
from ..service.attestation.attestation_service import AttestationService
from ..service.blob_storage.service import BlobStorageService
from ..service.code.status_service import UserCodeStatusService
from ..service.code.user_code_service import UserCodeService
Expand Down Expand Up @@ -877,6 +878,7 @@ def _construct_services(self) -> None:
default_services: list[dict] = [
{"svc": ActionService, "store": self.action_store},
{"svc": UserService},
{"svc": AttestationService},
{"svc": WorkerService},
{"svc": SettingsService},
{"svc": DatasetService},
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
ATTESTATION_SERVICE_URL = (
"http://localhost:4455" # Replace with "http://attestation:4455"
)
ATTEST_CPU_ENDPOINT = "/attest/cpu"
ATTEST_GPU_ENDPOINT = "/attest/gpu"
66 changes: 66 additions & 0 deletions packages/syft/src/syft/service/attestation/attestation_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# stdlib
from collections.abc import Callable

# third party
import requests

# relative
from ...serde.serializable import serializable
from ...store.document_store import DocumentStore
from ...util.util import str_to_bool
from ..context import AuthedServiceContext
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import GUEST_ROLE_LEVEL
from .attestation_constants import ATTESTATION_SERVICE_URL
from .attestation_constants import ATTEST_CPU_ENDPOINT
from .attestation_constants import ATTEST_GPU_ENDPOINT


@serializable()
class AttestationService(AbstractService):
"""This service is responsible for getting all sorts of attestations for any client."""

def __init__(self, store: DocumentStore) -> None:
self.store = store

def perform_request(
self, method: Callable, endpoint: str, raw: bool = False
) -> SyftSuccess | SyftError | str:
try:
response = method(f"{ATTESTATION_SERVICE_URL}{endpoint}")
response.raise_for_status()
message = response.json().get("result")
raw_token = response.json().get("token")
if raw:
return raw_token
elif str_to_bool(message):
return SyftSuccess(message=message)
else:
return SyftError(message=message)
except requests.HTTPError:
return SyftError(message=f"{response.json()['detail']}")
except requests.RequestException as e:
return SyftError(message=f"Failed to perform request. {e}")

@service_method(
path="attestation.get_cpu_attestation",
name="get_cpu_attestation",
roles=GUEST_ROLE_LEVEL,
)
def get_cpu_attestation(
self, context: AuthedServiceContext, raw_token: bool = False
) -> str | SyftError | SyftSuccess:
return self.perform_request(requests.get, ATTEST_CPU_ENDPOINT, raw_token)

@service_method(
path="attestation.get_gpu_attestation",
name="get_gpu_attestation",
roles=GUEST_ROLE_LEVEL,
)
def get_gpu_attestation(
self, context: AuthedServiceContext, raw_token: bool = False
) -> str | SyftError | SyftSuccess:
return self.perform_request(requests.get, ATTEST_GPU_ENDPOINT, raw_token)

0 comments on commit de906b9

Please sign in to comment.