Skip to content

Commit

Permalink
fix(tasks): enhanced model training zipping by adding input image pat…
Browse files Browse the repository at this point in the history
…h and improving file copying
  • Loading branch information
kshitijrajsharma committed Dec 6, 2024
1 parent bba0dfc commit eb0eb4f
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions backend/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def ramp_model_training(
shutil.rmtree(output_path)
shutil.copytree(final_model_path, os.path.join(output_path, "checkpoint.tf"))
shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
shutil.copytree(
model_input_image_path, os.path.join(output_path, "preprocessed", "input")
)

graph_output_path = f"{base_path}/train/graphs"
shutil.copytree(graph_output_path, os.path.join(output_path, "graphs"))
Expand Down Expand Up @@ -374,11 +377,30 @@ def yolo_model_training(
os.path.join(os.path.dirname(output_model_path), "best.onnx"),
os.path.join(output_path, "checkpoint.onnx"),
)
shutil.copyfile(
os.path.join(os.path.dirname(output_model_path), "best.onnx"),
os.path.join(output_path, "checkpoint.onnx"),
)
# shutil.copyfile(os.path.dirname(output_model_path,'checkpoint.tflite'), os.path.join(output_path, "checkpoint.tflite"))

shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
shutil.copytree(
model_input_image_path, os.path.join(output_path, "preprocessed", "input")
)
os.makedirs(os.path.join(output_path, model), exist_ok=True)

shutil.copytree(
os.path.join(yolo_data_dir, "images"),
os.path.join(output_path, model, "images"),
)
shutil.copytree(
os.path.join(yolo_data_dir, "labels"),
os.path.join(output_path, model, "labels"),
)
shutil.copyfile(
os.path.join(yolo_data_dir, "yolo_dataset.yaml"),
os.path.join(output_path, model, "yolo_dataset.yaml"),
)
shutil.copytree(
os.path.join(yolo_data_dir, "images"),
os.path.join(output_path, model, "images"),
Expand Down Expand Up @@ -473,18 +495,17 @@ def train_model(
if training_instance.task_id is None or training_instance.task_id.strip() == "":
training_instance.task_id = train_model.request.id
training_instance.save()
log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}.log")
log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}_log.txt")

if model_instance.base_model == "YOLO_V8_V1" and settings.YOLO_HOME is None:
raise ValueError("YOLO Home is not configured")
elif model_instance.base_model != "YOLO_V8_V1" and settings.RAMP_HOME is None:
raise ValueError("Ramp Home is not configured")

try:
with open(log_file, "a") as f:
with open(log_file, "w") as f:
# redirect stdout to the log file
sys.stdout = f
logging.info("Training Started")
training_input_image_source, aoi_serializer, serialized_field = (
prepare_data(
training_instance, dataset_id, feedback, zoom_level, source_imagery
Expand Down

0 comments on commit eb0eb4f

Please sign in to comment.