Skip to content

Commit

Permalink
A potential solution
Browse files Browse the repository at this point in the history
  • Loading branch information
snwagh committed Jun 25, 2024
1 parent c30424e commit 05000a1
Showing 1 changed file with 80 additions and 10 deletions.
90 changes: 80 additions & 10 deletions packages/syft/src/syft/node/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def make_app(name: str, router: APIRouter) -> FastAPI:
}


def run_reloadable_app() -> None:
# Read environment variables
name = os.getenv("NODE_NAME", "testing-node")
processes = int(os.getenv("PROCESSES", "1"))
reset = os.getenv("RESET", "False").lower() in ("true")
local_db = os.getenv("LOCAL_DB", "True").lower() in ("true")
node_type = NodeType(os.getenv("NODE_TYPE", NodeType.DOMAIN))
node_side_type = NodeSideType(os.getenv("NODE_SIDE_TYPE", NodeSideType.HIGH_SIDE))

process_list = []
def create_app(
name: str,
node_type: NodeType,
node_side_type: NodeSideType,
processes: int,
reset: bool,
local_db: bool
) -> FastAPI:
# Print variables for debugging
print("*" * 50)
print("Starting uvicorn app with the following settings:")
Expand All @@ -83,7 +83,6 @@ def run_reloadable_app() -> None:
print(f"NODE_SIDE_TYPE: {node_side_type}")
print("*" * 50)

# Assuming worker_classes, make_routes, and make_app are defined elsewhere
worker_type = worker_classes[node_type]
worker = worker_type.named(
name=name,
Expand All @@ -106,6 +105,77 @@ def run_reloadable_app() -> None:
app = make_app(worker.name, router=router)
return app

def app_factory():
name = os.getenv("NODE_NAME")
node_type = NodeType(os.getenv("NODE_TYPE", "domain"))
node_side_type = NodeSideType(os.getenv("NODE_SIDE_TYPE", "high"))
processes = int(os.getenv("PROCESSES", 1))
local_db = os.getenv("LOCAL_DB", "True") == "True"
reset = os.getenv("RESET", "False") == "False"

return create_app(
name=name,
node_type=node_type,
node_side_type=node_side_type,
processes=processes,
reset=reset,
local_db=local_db
)

def start_uvicorn_server(
name: str = "testing-node",
node_type: str = "domain",
node_side_type: str = "high",
port: int = 9081,
processes: int = 1,
local_db: bool = True,
reset: bool = False,
):
os.environ["NODE_NAME"] = name
os.environ["NODE_TYPE"] = node_type
os.environ["NODE_SIDE_TYPE"] = node_side_type
os.environ["PORT"] = str(port)
os.environ["PROCESSES"] = str(processes)
os.environ["LOCAL_DB"] = str(local_db)
os.environ["RESET"] = str(reset)

command = ["python", "-m", "syft.node.server"]
process = subprocess.Popen(command)
process_list.append(process)
print(f"Uvicorn server running on port {port} with PID: {process.pid}")

from syft.orchestra import NodeHandle, DeploymentType

# Return this object:
return NodeHandle(
node_type=NodeType(node_type),
deployment_type=DeploymentType.PYTHON,
name=name,
port=port,
url="http://localhost",
node_side_type=NodeSideType(node_side_type),
# shutdown=stop,
)
# return process

def stop_all_uvicorn_servers():
for process in process_list:
process.terminate()
process.wait()
process_list.clear()
print("All Uvicorn servers stopped.")

if __name__ == "__main__":
current_file_path = os.path.dirname(os.path.abspath(__file__))
reload_dirs = os.path.abspath(os.path.join(current_file_path, "../../"))

uvicorn.run("syft.node.server:app_factory",
host="0.0.0.0",
port=int(os.getenv("PORT", 9081)),
reload=True,
factory=True,
reload_dirs=[reload_dirs],
)

def run_uvicorn(
name: str,
Expand Down

0 comments on commit 05000a1

Please sign in to comment.