Skip to content

Commit

Permalink
Merge branch 'main' into feature/precision_improvement
Browse files Browse the repository at this point in the history
# Conflicts:
#	training/train.py
  • Loading branch information
Flippchen committed Feb 8, 2024
2 parents b52c613 + 297e5f7 commit 4732913
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 6 deletions.
Binary file added assets/banner_gif.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion training/old/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@

# Train model
epochs = 20
with tf.device('/GPU:0'):
device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0'
print("Using Device:", device)
with tf.device(device):
history = model.fit(
train_ds,
validation_data=val_ds,
Expand Down
4 changes: 3 additions & 1 deletion training/old/with_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@

# Train model
epochs = 20
with tf.device('/GPU:1'):
device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0'
print("Using Device:", device)
with tf.device(device):
history = model.fit(
train_ds,
validation_data=val_ds,
Expand Down
4 changes: 3 additions & 1 deletion training/old/without_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@

# Train model
epochs = 20
with tf.device('/GPU:1'):
device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0'
print("Using Device:", device)
with tf.device(device):
history = model.fit(
train_ds,
validation_data=val_ds,
Expand Down
4 changes: 3 additions & 1 deletion training/pre_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@

# Train model
epochs = 15
with tf.device('/GPU:0'):
device = tf.test.gpu_device_name() if tf.test.is_gpu_available() else '/CPU:0'
print("Using Device:", device)
with tf.device(device):
history = model.fit(
train_ds,
validation_data=val_ds,
Expand Down
8 changes: 6 additions & 2 deletions training/train.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# This file contains the code for training a model with data augmentation and a pretrained base.
# Import libraries
import keras
from keras import layers
from keras.models import Sequential
from keras.applications import EfficientNetV2B1
from utilities.tools import *
from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch, show_batch_shape
from utilities.discord_callback import DiscordCallback
from keras.optimizers import AdamW
from keras.regularizers import l1_l2
from keras.callbacks import EarlyStopping, ModelCheckpoint
import os
import random

import tensorflow as tf
# Ignore warnings
import warnings

Expand All @@ -32,9 +35,10 @@
# Set seed for reproducibility
random_seed = True
# Config
base_path = get_base_path()
path_addon = get_data_path_addon(model_type)
config = {
"path": f"C:/Users\phili/.keras/datasets/resized_DVM/{path_addon}",
"path": f"{base_path}/{path_addon}",
"batch_size": 32,
"img_height": img_height,
"img_width": img_width,
Expand Down
15 changes: 15 additions & 0 deletions utilities/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from keras import layers
import os
import logging
import platform


def load_dataset(path: str, batch_size: int, img_height: int, img_width: int, seed: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]:
Expand Down Expand Up @@ -317,3 +318,17 @@ def get_data_path_addon(name: str) -> str:
return "pre_filter"
else:
raise ValueError("Invalid model name")


def get_base_path():
# Determine the base path depending on the operating system
if platform.system() == 'Windows':
base_path = r"C:/Users\phili/.keras/datasets/resized_DVM"
elif platform.system() == 'Linux':
base_path = "/home/luke/datasets/"
elif platform.system() == 'Darwin': # Darwin is the system name for macOS
base_path = "/Users/flippchen/datasets/"
else:
raise ValueError("Operating system not supported.")

return base_path

0 comments on commit 4732913

Please sign in to comment.