Skip to content

Commit

Permalink
Merge pull request #62 from demml/fix-streaming
Browse files Browse the repository at this point in the history
Fix streaming

Former-commit-id: d978bc8cd1296ace11aa73012b9f8fbee74a1dbd [formerly 9163a46]
Former-commit-id: dc841b8
  • Loading branch information
thorrester authored Nov 1, 2024
2 parents d032dd5 + 2803697 commit 28954a9
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 69 deletions.
30 changes: 0 additions & 30 deletions opsml/app/core/gunicorn.py

This file was deleted.

13 changes: 0 additions & 13 deletions opsml/app/gunicorn_conf.py

This file was deleted.

14 changes: 7 additions & 7 deletions opsml/app/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class UserDeleted(BaseModel):
router = APIRouter()


async def get_current_user(
def get_current_user(
request: Request,
token: Annotated[str, Depends(oauth2_scheme)],
) -> User:
Expand Down Expand Up @@ -98,7 +98,7 @@ async def get_current_user(
return user


async def get_current_active_user(
def get_current_active_user(
current_user: Annotated[User, Depends(get_current_user)],
) -> User:
if not current_user.is_active:
Expand All @@ -107,7 +107,7 @@ async def get_current_active_user(


@router.post("/auth/token")
async def login_for_access_token(
def login_for_access_token(
request: Request,
response: Response,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
Expand Down Expand Up @@ -149,7 +149,7 @@ async def login_for_access_token(


@router.get("/auth/token/rotate")
async def create_refresh_token(
def create_refresh_token(
request: Request,
response: Response,
refresh_token: Annotated[Union[str, None], Cookie()] = None,
Expand All @@ -175,7 +175,7 @@ async def create_refresh_token(
try:
# check user
auth_db: ServerAuthRegistry = request.app.state.auth_db
user = await get_current_user(request, refresh_token)
user = get_current_user(request, refresh_token)

# create new access token
refresh_token = auth_db.create_access_token(user, minutes=60)
Expand All @@ -193,7 +193,7 @@ async def create_refresh_token(


@router.get("/auth/token/refresh")
async def get_refresh_from_cookie(
def get_refresh_from_cookie(
request: Request,
response: Response,
refresh_token: Annotated[Union[str, None], Cookie()] = None,
Expand All @@ -207,7 +207,7 @@ async def get_refresh_from_cookie(
headers={"WWW-Authenticate": "Bearer"},
)

user = await get_current_user(request, refresh_token)
user = get_current_user(request, refresh_token)
logger.info("Refreshing token for user: {}", user.username)

# create new access token
Expand Down
8 changes: 4 additions & 4 deletions opsml/app/routes/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@


@router.get("/data/download", name="download_data")
async def download_data(request: Request, uid: str) -> StreamingResponse:
def download_data(request: Request, uid: str) -> StreamingResponse:
"""Downloads data associated with a datacard"""

registry: CardRegistry = request.app.state.registries.data
datacard = cast(DataCard, registry.load_card(uid=uid))
load_path = Path(datacard.uri / SaveName.DATA.value).with_suffix(datacard.interface.data_suffix)
return await download_artifacts_ui(request, str(load_path))
return download_artifacts_ui(request, str(load_path))


@router.get("/data/download/profile", name="download_data_profile")
async def download_data_profile(
def download_data_profile(
request: Request,
uid: str,
) -> StreamingResponse:
Expand All @@ -49,7 +49,7 @@ async def download_data_profile(
registry: CardRegistry = request.app.state.registries.data
datacard = cast(DataCard, registry.load_card(uid=uid))
load_path = Path(datacard.uri / SaveName.DATA_PROFILE.value).with_suffix(Suffix.HTML.value)
return await download_file(request, str(load_path))
return download_file(request, str(load_path))


@router.post("/data/card", name="data_card", response_model=DataCardMetadata)
Expand Down
28 changes: 16 additions & 12 deletions opsml/app/routes/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import streaming_form_data
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from starlette.concurrency import run_in_threadpool
from starlette.requests import ClientDisconnect
from streaming_form_data import StreamingFormDataParser
from streaming_form_data.validators import MaxSizeValidator
Expand Down Expand Up @@ -114,7 +113,7 @@ async def upload_file(request: Request) -> Dict[str, str]: # pragma: no cover


@router.get("/files/download", name="download_file")
async def download_file(request: Request, path: str) -> StreamingResponse:
def download_file(request: Request, path: str) -> StreamingResponse:
"""Downloads a file
Args:
Expand All @@ -130,8 +129,11 @@ async def download_file(request: Request, path: str) -> StreamingResponse:
storage_client: StorageClientBase = request.app.state.storage_client
try:
file_path = Path(swap_opsml_root(request, Path(path)))
file_iterator = await run_in_threadpool(storage_client.iterfile, file_path, config.download_chunk_size)
return StreamingResponse(file_iterator, media_type="application/octet-stream")

return StreamingResponse(
storage_client.iterfile(file_path, config.download_chunk_size),
media_type="application/octet-stream",
)

except Exception as error:
logger.error("Server: Error downloading file {}", path)
Expand All @@ -141,7 +143,7 @@ async def download_file(request: Request, path: str) -> StreamingResponse:
) from error


async def download_dir(request: Request, path: Path) -> StreamingResponse:
def download_dir(request: Request, path: Path) -> StreamingResponse:
"""Downloads a file
Args:
Expand Down Expand Up @@ -170,14 +172,16 @@ async def download_dir(request: Request, path: Path) -> StreamingResponse:
curr_rpath = Path(file_)
curr_lpath = lpath / curr_rpath.relative_to(rpath)
logger.info("Server: Downloading {} to {}", curr_rpath, curr_lpath)
await run_in_threadpool(storage_client.get, curr_rpath, curr_lpath)
storage_client.get(curr_rpath, curr_lpath)
zip_filepath = zipfile / curr_rpath.relative_to(rpath)
temp_zip.write(curr_lpath, zip_filepath)

logger.info("Server: Sending zip file for {}", path)
iter_buffer = await run_in_threadpool(storage_client.iterbuffer, zip_io, config.download_chunk_size)

return StreamingResponse(iter_buffer, media_type="application/x-zip-compressed")
return StreamingResponse(
storage_client.iterbuffer(zip_io, config.download_chunk_size),
media_type="application/x-zip-compressed",
)

except Exception as error:
raise HTTPException(
Expand All @@ -187,7 +191,7 @@ async def download_dir(request: Request, path: Path) -> StreamingResponse:


@router.get("/files/download/ui", name="download_artifacts")
async def download_artifacts_ui(request: Request, path: str) -> StreamingResponse:
def download_artifacts_ui(request: Request, path: str) -> StreamingResponse:
"""Downloads a file
Args:
Expand All @@ -200,8 +204,8 @@ async def download_artifacts_ui(request: Request, path: str) -> StreamingRespons
Streaming file response
"""
if Path(path).suffix == "":
return await download_dir(request, Path(path))
return await download_file(request, path)
return download_dir(request, Path(path))
return download_file(request, path)


@router.get("/files/list", name="list_files")
Expand Down Expand Up @@ -395,7 +399,7 @@ def get_file_to_view(request: Request, path: str) -> FileViewResponse:


@router.post("/files/readme", name="create_readme")
async def create_readme(
def create_readme(
request: Request,
payload: ReadMeRequest,
) -> bool:
Expand Down
5 changes: 3 additions & 2 deletions opsml/app/routes/healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


@router.get("/healthcheck", response_model=HealthCheckResult, name="healthcheck")
def get_healthcheck() -> HealthCheckResult:
async def get_healthcheck() -> HealthCheckResult:
return HealthCheckResult(is_alive=True)


Expand All @@ -21,12 +21,13 @@ async def debug() -> DebugResponse:
url=config.opsml_tracking_uri,
storage=config.opsml_storage_uri,
app_env=config.app_env,
app_version=config.app_version,
)


@router.get(
"/error",
description="An endpoint that will return a 500 error for debugging and alert testing",
)
def get_error() -> None:
async def get_error() -> None:
raise HTTPException(status_code=500)
1 change: 1 addition & 0 deletions opsml/app/routes/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class DebugResponse(BaseModel):
url: str
storage: str
app_env: str
app_version: str


class StorageSettingsResponse(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion opsml/app/routes/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_runcard(request: Request, payload: CardRequest) -> RunCard:


@router.get("/runs/graphs", name="graphs")
async def get_graph_plots(request: Request, repository: str, name: str, version: str) -> Dict[str, Any]:
def get_graph_plots(request: Request, repository: str, name: str, version: str) -> Dict[str, Any]:
"""Method for loading plots for a run
Args:
Expand Down
2 changes: 2 additions & 0 deletions opsml/settings/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pydantic_settings import BaseSettings

from opsml.types import StorageSystem
from opsml.version import __version__


class OpsmlAuthSettings(BaseModel):
Expand All @@ -23,6 +24,7 @@ class OpsmlAuthSettings(BaseModel):
class OpsmlConfig(BaseSettings):
app_name: str = "opsml"
app_env: str = "development"
app_version: str = __version__

opsml_storage_uri: str = "./opsml_registries"
opsml_tracking_uri: str = "sqlite:///opsml.db"
Expand Down

0 comments on commit 28954a9

Please sign in to comment.