Skip to content

Commit

Permalink
Add checkpoint.tflite copy to ramp_model_training and clean up yolo_m…
Browse files Browse the repository at this point in the history
…odel_training
  • Loading branch information
kshitijrajsharma committed Dec 13, 2024
1 parent 3e99467 commit a7befc7
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions backend/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,10 @@ def ramp_model_training(
if os.path.exists(output_path):
shutil.rmtree(output_path)
shutil.copytree(final_model_path, os.path.join(output_path, "checkpoint.tf"))
shutil.copyfile(
os.path.join(os.path.dirname(final_model_path), "checkpoint.tflite"),
os.path.join(output_path, "checkpoint.tflite"),
)
shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
shutil.copytree(
model_input_image_path, os.path.join(output_path, "preprocessed", "input")
Expand Down Expand Up @@ -400,11 +404,6 @@ def yolo_model_training(
os.path.join(os.path.dirname(output_model_path), "best.onnx"),
os.path.join(output_path, "checkpoint.onnx"),
)
shutil.copyfile(
os.path.join(os.path.dirname(output_model_path), "best.onnx"),
os.path.join(output_path, "checkpoint.onnx"),
)
# shutil.copyfile(os.path.dirname(output_model_path,'checkpoint.tflite'), os.path.join(output_path, "checkpoint.tflite"))

shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
shutil.copytree(
Expand Down

0 comments on commit a7befc7

Please sign in to comment.