-
Notifications
You must be signed in to change notification settings - Fork 44
/
backend_app.py
47 lines (40 loc) · 1.34 KB
/
backend_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from flask import Flask, request
from flask import render_template
from flask_cors import CORS
from get_model_graph import get_model_graph
from backend_settings import avaliable_hardwares,avaliable_model_ids
import argparse
# Flask application object; the route handlers in this file are registered on it.
app = Flask(__name__)
# Enable CORS on every path ("/*") for requests from any origin ("*").
# NOTE(review): a wildcard origin is convenient for local development —
# confirm it is intended for any public deployment.
cors = CORS(app, resources={r"/*": {"origins": "*"}})
@app.route("/")
def index():
    """Root endpoint: reply with a fixed readiness message."""
    ready_message = "backend server ready."
    return ready_message
@app.route("/get_graph", methods=["POST"])
def get_graph():
    """Run ``get_model_graph`` for the model/hardware named in the JSON body.

    The request body must be a JSON object containing the keys
    ``model_id``, ``hardware`` and ``inference_config``; their values are
    forwarded unchanged to ``get_model_graph``.

    Returns:
        On success, a JSON object with the keys ``nodes``, ``edges``,
        ``total_results`` and ``hardware_info``.
        If the body is not a JSON object or a required key is missing, a
        JSON object with an ``error`` message and HTTP status 400 (instead
        of an unhandled KeyError/TypeError surfacing as a 500).
    """
    # silent=True returns None rather than raising when the body is absent
    # or not valid JSON, so all malformed requests take the same 400 path.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return {"error": "request body must be a JSON object"}, 400

    required_keys = ("model_id", "hardware", "inference_config")
    missing = [key for key in required_keys if key not in payload]
    if missing:
        return {
            "error": "missing required field(s): " + ", ".join(missing)
        }, 400

    # The third positional argument is passed as None, exactly as before;
    # its meaning is defined by get_model_graph, not by this handler.
    nodes, edges, total_results, hardware_info = get_model_graph(
        payload["model_id"],
        payload["hardware"],
        None,
        payload["inference_config"],
    )
    return {
        "nodes": nodes,
        "edges": edges,
        "total_results": total_results,
        "hardware_info": hardware_info,
    }
@app.route("/get_avaliable", methods=["GET"])
def get_avaliable():
    """Return the hardware and model-id collections from backend_settings.

    The response keys keep the historical ``avaliable_*`` spelling because
    clients read them by that name.
    """
    listing = dict(
        avaliable_hardwares=avaliable_hardwares,
        avaliable_model_ids=avaliable_model_ids,
    )
    return listing
if __name__ == "__main__":
    # Command-line options used when this module is run directly.
    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=5000)
    cli.add_argument("--local", action="store_true")
    cli.add_argument("--debug", action="store_true")
    options = cli.parse_args()

    # --local restricts the server to the loopback interface; without it
    # the server listens on all interfaces.
    if options.local:
        bind_host = "127.0.0.1"
    else:
        bind_host = "0.0.0.0"

    # NOTE(review): --debug without --local runs Flask debug mode on all
    # interfaces — confirm this is only ever used on trusted networks.
    app.run(host=bind_host, port=options.port, debug=options.debug)