
Commit

fix(log-production): fix bugs in epoch limits and log production
kshitijrajsharma committed Dec 3, 2024
1 parent 3f48e8c commit bba0dfc
Showing 3 changed files with 65 additions and 28 deletions.
11 changes: 9 additions & 2 deletions backend/aiproject/settings.py
@@ -56,8 +56,15 @@


# Limiter
EPOCHS_LIMIT = env("EPOCHS_LIMIT", default=30)
BATCH_SIZE_LIMIT = env("BATCH_SIZE_LIMIT", default=8)

## YOLO
YOLO_EPOCHS_LIMIT = env("YOLO_EPOCHS_LIMIT", default=200)
YOLO_BATCH_SIZE_LIMIT = env("YOLO_BATCH_SIZE_LIMIT", default=8)

## RAMP
RAMP_EPOCHS_LIMIT = env("RAMP_EPOCHS_LIMIT", default=40)
RAMP_BATCH_SIZE_LIMIT = env("RAMP_BATCH_SIZE_LIMIT", default=8)

TRAINING_WORKSPACE_DOWNLOAD_LIMIT = env(
"TRAINING_WORKSPACE_DOWNLOAD_LIMIT", default=200
)
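The limiter section above defines per-model caps for YOLO and RAMP trainings, read from the environment with in-code defaults. As a minimal sketch of how such env-backed settings resolve, assuming a django-environ-style env() like the one used in settings.py (the override value below is illustrative, not from the repository):

# Minimal sketch, assuming django-environ; not taken from the repository.
import os

import environ

env = environ.Env()

# Nothing set in the environment: the in-code default applies.
print(env("YOLO_EPOCHS_LIMIT", default=200))   # -> 200

# A deployment can override the cap via the process environment or a .env file.
os.environ["YOLO_EPOCHS_LIMIT"] = "100"
print(env("YOLO_EPOCHS_LIMIT", default=200))   # -> "100" (a string; use env.int(...) if an integer is required)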
49 changes: 34 additions & 15 deletions backend/core/tasks.py
@@ -35,7 +35,14 @@
from django.utils import timezone
from predictor import download_imagery, get_start_end_download_coords

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)


logger = logging.getLogger(__name__)
logger.propagate = False


# from core.serializers import LabelFileSerializer

@@ -363,15 +370,27 @@ def yolo_model_training(
os.makedirs(output_path)

shutil.copyfile(output_model_path, os.path.join(output_path, "checkpoint.pt"))
shutil.copyfile(os.path.join(os.path.dirname(output_model_path),'best.onnx'), os.path.join(output_path, "checkpoint.onnx"))
shutil.copyfile(
os.path.join(os.path.dirname(output_model_path), "best.onnx"),
os.path.join(output_path, "checkpoint.onnx"),
)
# shutil.copyfile(os.path.dirname(output_model_path,'checkpoint.tflite'), os.path.join(output_path, "checkpoint.tflite"))

shutil.copytree(preprocess_output, os.path.join(output_path, "preprocessed"))
os.makedirs(os.path.join(output_path,model),exist_ok=True)
os.makedirs(os.path.join(output_path, model), exist_ok=True)

shutil.copytree(os.path.join(yolo_data_dir,'images'), os.path.join(output_path,model, "images"))
shutil.copytree(os.path.join(yolo_data_dir,'labels'), os.path.join(output_path,model, "labels"))
shutil.copyfile(os.path.join(yolo_data_dir,'yolo_dataset.yaml'), os.path.join(output_path,model, "yolo_dataset.yaml"))
shutil.copytree(
os.path.join(yolo_data_dir, "images"),
os.path.join(output_path, model, "images"),
)
shutil.copytree(
os.path.join(yolo_data_dir, "labels"),
os.path.join(output_path, model, "labels"),
)
shutil.copyfile(
os.path.join(yolo_data_dir, "yolo_dataset.yaml"),
os.path.join(output_path, model, "yolo_dataset.yaml"),
)

graph_output_path = os.path.join(
pathlib.Path(os.path.dirname(output_model_path)).parent, "iou_chart.png"
@@ -454,22 +473,22 @@ def train_model(
if training_instance.task_id is None or training_instance.task_id.strip() == "":
training_instance.task_id = train_model.request.id
training_instance.save()
log_file = os.path.join(
settings.LOG_PATH, f"run_{train_model.request.id}_log.txt"
)

log_file = os.path.join(settings.LOG_PATH, f"run_{train_model.request.id}.log")

if model_instance.base_model == "YOLO_V8_V1" and settings.YOLO_HOME is None:
raise ValueError("YOLO Home is not configured")
elif model_instance.base_model != "YOLO_V8_V1" and settings.RAMP_HOME is None:
raise ValueError("Ramp Home is not configured")

try:
with open(log_file, "w") as f:
# redirect stdout to the log file
with open(log_file, "a") as f:
# redirect stdout to the log file
sys.stdout = f
training_input_image_source, aoi_serializer, serialized_field = prepare_data(
training_instance, dataset_id, feedback, zoom_level, source_imagery
logging.info("Training Started")
training_input_image_source, aoi_serializer, serialized_field = (
prepare_data(
training_instance, dataset_id, feedback, zoom_level, source_imagery
)
)

if model_instance.base_model in ("YOLO_V8_V1", "YOLO_V8_V2"):
@@ -499,7 +518,7 @@ def train_model(
input_boundary_width,
)

logger.info(f"Training task {training_id} completed successfully")
logging.info(f"Training task {training_id} completed successfully")
return response

except Exception as ex:
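One plausible reading of the logging changes in this file (logging.basicConfig plus logger.propagate = False, the run log renamed to run_<id>.log and opened in append mode, and the completion message switched from logger.info to logging.info): once propagation is disabled, a module logger with no handlers of its own never reaches the handler installed by basicConfig and its INFO records are effectively dropped, whereas logging.info() goes through the root logger and is emitted. A minimal standalone sketch of that behaviour, not taken from the repository:

# Minimal sketch of Python logging propagation; illustrative only.
import logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)

logger = logging.getLogger("core.tasks")
logger.propagate = False    # records no longer bubble up to the root logger

logger.info("dropped")      # no handler here and no propagation: nothing is emitted at INFO
logging.info("written")     # root logger: handled by the basicConfig handler

Under that reading, routing the progress and completion messages through logging.info() keeps them from being silently discarded.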
33 changes: 22 additions & 11 deletions backend/core/views.py
@@ -150,16 +150,25 @@ def create(self, validated_data):

epochs = validated_data["epochs"]
batch_size = validated_data["batch_size"]
if model.base_model == "RAMP":
if epochs > settings.RAMP_EPOCHS_LIMIT:
raise ValidationError(
f"Epochs can't be greater than {settings.RAMP_EPOCHS_LIMIT} on this server"
)
if batch_size > settings.RAMP_BATCH_SIZE_LIMIT:
raise ValidationError(
f"Batch size can't be greater than {settings.RAMP_BATCH_SIZE_LIMIT} on this server"
)
if model.base_model in ["YOLO_V8_V1","YOLO_V8_V2"]:

if epochs > settings.EPOCHS_LIMIT:
raise ValidationError(
f"Epochs can't be greater than {settings.EPOCHS_LIMIT} on this server"
)
if batch_size > settings.BATCH_SIZE_LIMIT:
raise ValidationError(
f"Batch size can't be greater than {settings.BATCH_SIZE_LIMIT} on this server"
)

if epochs > settings.YOLO_EPOCHS_LIMIT:
raise ValidationError(
f"Epochs can't be greater than {settings.YOLO_EPOCHS_LIMIT} on this server"
)
if batch_size > settings.YOLO_BATCH_SIZE_LIMIT:
raise ValidationError(
f"Batch size can't be greater than {settings.YOLO_BATCH_SIZE_LIMIT} on this server"
)
user = self.context["request"].user
validated_data["user"] = user
# create the model instance
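The create() validation above now checks epochs and batch_size against the RAMP_* limits for the RAMP base model and against the YOLO_* limits for the YOLO variants. A table-driven equivalent of that branching, sketched with placeholder limit values and ValueError standing in for DRF's ValidationError (the names and numbers are illustrative, not from the repository):

# Hypothetical sketch of per-model limit selection; values are placeholders.
LIMITS = {
    "RAMP": {"epochs": 40, "batch_size": 8},         # mirrors settings.RAMP_*
    "YOLO_V8_V1": {"epochs": 200, "batch_size": 8},  # mirrors settings.YOLO_*
    "YOLO_V8_V2": {"epochs": 200, "batch_size": 8},
}


def validate_limits(base_model: str, epochs: int, batch_size: int) -> None:
    limits = LIMITS[base_model]
    if epochs > limits["epochs"]:
        raise ValueError(f"Epochs can't be greater than {limits['epochs']} on this server")
    if batch_size > limits["batch_size"]:
        raise ValueError(f"Batch size can't be greater than {limits['batch_size']} on this server")


validate_limits("RAMP", epochs=30, batch_size=8)           # passes
# validate_limits("YOLO_V8_V1", epochs=300, batch_size=8)  # would raise ValueError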
@@ -553,11 +562,13 @@ def run_task_status(request, run_id: str):
}
)
elif task_result.state == "PENDING" or task_result.state == "STARTED":
log_file = os.path.join(settings.LOG_PATH, f"run_{run_id}_log.txt")
log_file = os.path.join(settings.LOG_PATH, f"run_{run_id}.log")
try:
# read the last LOG_LINE_STREAM_TRUNCATE_VALUE lines of the log file
cmd = ["tail", "-n", str(settings.LOG_LINE_STREAM_TRUNCATE_VALUE), log_file]
# print(cmd)
output = subprocess.check_output(
["tail", "-n", settings.LOG_LINE_STREAM_TRUNCATE_VALUE, log_file]
cmd
).decode("utf-8")
except Exception as e:
output = str(e)
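The run_task_status change above builds the tail command as cmd with str(settings.LOG_LINE_STREAM_TRUNCATE_VALUE): subprocess argument lists must contain strings (or path-like objects), so passing the integer setting directly would raise a TypeError. A minimal standalone sketch of the pattern, with an illustrative path and line count:

# Minimal sketch: tail the last N lines of a run log; path and N are illustrative.
import subprocess

log_file = "run_example.log"   # hypothetical path; the view builds it from settings.LOG_PATH
truncate = 10                  # stands in for settings.LOG_LINE_STREAM_TRUNCATE_VALUE

cmd = ["tail", "-n", str(truncate), log_file]   # every argument must be a string
output = subprocess.check_output(cmd).decode("utf-8")
print(output)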
