diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index 43b8359a1f9..d257b4fc017 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -63,6 +63,66 @@ def make_app(name: str, router: APIRouter) -> FastAPI: } +def create_app( + name: str, + node_type: NodeType, + node_side_type: NodeSideType, + processes: int, + reset: bool, + local_db: bool, +) -> FastAPI: + # Print variables for debugging + print("*" * 50, flush=True) + print("Starting uvicorn app with the following settings:", flush=True) + print(f"NODE_NAME: {name}", flush=True) + print(f"PROCESSES: {processes}", flush=True) + print(f"RESET: {reset}", flush=True) + print(f"LOCAL_DB: {local_db}", flush=True) + print(f"NODE_TYPE: {node_type}", flush=True) + print(f"NODE_SIDE_TYPE: {node_side_type}", flush=True) + print("*" * 50, flush=True) + + worker_type = worker_classes[node_type] + worker = worker_type.named( + name=name, + processes=processes, + reset=reset, + local_db=local_db, + node_type=node_type, + node_side_type=node_side_type, + enable_warnings=False, + migrate=True, + in_memory_workers=True, + queue_port=None, + create_producer=False, + n_consumers=0, + association_request_auto_approval=False, + background_tasks=False, + ) + + router = make_routes(worker=worker) + app = make_app(worker.name, router=router) + return app + + +def app_factory() -> FastAPI: + name = os.getenv("NODE_NAME") + node_type = NodeType(os.getenv("NODE_TYPE", "domain")) + node_side_type = NodeSideType(os.getenv("NODE_SIDE_TYPE", "high")) + processes = int(os.getenv("PROCESSES", 1)) + local_db = os.getenv("LOCAL_DB", "True") == "True" + reset = os.getenv("RESET", "False") == "False" + + return create_app( + name=name, + node_type=node_type, + node_side_type=node_side_type, + processes=processes, + reset=reset, + local_db=local_db, + ) + + def run_uvicorn( name: str, node_type: Enum, diff --git a/packages/syft/src/syft/node/server_management.py b/packages/syft/src/syft/node/server_management.py new file mode 100644 index 00000000000..eba02a18afd --- /dev/null +++ b/packages/syft/src/syft/node/server_management.py @@ -0,0 +1,91 @@ +# stdlib +import os +import subprocess + +# third party +import uvicorn + +# relative +from ..abstract_node import NodeSideType + +# from .server import app_factory +from ..orchestra import DeploymentType +from ..orchestra import NodeHandle +from .node import NodeType + +# List storing all reloadable servers launched +process_list = [] + + +def start_reloadable_server( + name: str = "testing-node", + node_type: str = "domain", + node_side_type: str = "high", + port: int = 9081, + processes: int = 1, + local_db: bool = True, + reset: bool = False, +) -> NodeHandle: + os.environ["NODE_NAME"] = name + os.environ["NODE_TYPE"] = node_type + os.environ["NODE_SIDE_TYPE"] = node_side_type + os.environ["PORT"] = str(port) + os.environ["PROCESSES"] = str(processes) + os.environ["LOCAL_DB"] = str(local_db) + os.environ["RESET"] = str(reset) + + command = ["python", "-m", "syft.node.server_management"] + process = subprocess.Popen(command) + process_list.append(process) + print("*" * 50, flush=True) + print(f"Uvicorn server running on port {port} with PID: {process.pid}", flush=True) + print("*" * 50, flush=True) + + # Since the servers take a second to run, adding this wait so + # that notebook commands can run one after the other. + # stdlib + from time import sleep + + sleep(6) + + def stop() -> None: + process.terminate() + process.wait() + if process in process_list: + process_list.remove(process) + print("*" * 50, flush=True) + print(f"Uvicorn server with PID: {process.pid} stopped.", flush=True) + print("*" * 50, flush=True) + + # Return this object: + return NodeHandle( + node_type=NodeType(node_type), + deployment_type=DeploymentType.PYTHON, + name=name, + port=port, + url="http://localhost", + node_side_type=NodeSideType(node_side_type), + shutdown=stop, + ) + + +def stop_all_reloadable_servers() -> None: + for process in process_list: + process.terminate() + process.wait() + process_list.clear() + print("All Uvicorn servers stopped.") + + +if __name__ == "__main__": + current_file_path = os.path.dirname(os.path.abspath(__file__)) + reload_dirs = os.path.abspath(os.path.join(current_file_path, "../../")) + + uvicorn.run( + "syft.node.server:app_factory", + host="0.0.0.0", + port=int(os.getenv("PORT", 9081)), + reload=True, + factory=True, + reload_dirs=[reload_dirs], + )