Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/precision improvement #92

Merged
merged 17 commits into from
Feb 8, 2024
Merged
6 changes: 4 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_load_dataset(tmp_path):
p2 = d2 / f"img{i + 1}.jpg"
p2.write_text("fake image data")

train_ds, val_ds, class_names = load_dataset(str(d), 2, 32, 32)
train_ds, val_ds, class_names = load_dataset(str(d), 2, 32, 32, 123)

assert len(train_ds) == 4
assert len(val_ds) == 1
Expand All @@ -30,10 +30,12 @@ def test_load_dataset(tmp_path):

def test_create_augmentation_layer():
data_augmentation = create_augmentation_layer(32, 32)
assert len(data_augmentation.layers) == 3
assert len(data_augmentation.layers) == 5
assert isinstance(data_augmentation.layers[0], tf.keras.layers.RandomFlip)
assert isinstance(data_augmentation.layers[1], tf.keras.layers.RandomRotation)
assert isinstance(data_augmentation.layers[2], tf.keras.layers.RandomZoom)
assert isinstance(data_augmentation.layers[3], tf.keras.layers.RandomContrast)
assert isinstance(data_augmentation.layers[4], tf.keras.layers.GaussianNoise)


def test_get_data_path_addon():
Expand Down
19 changes: 9 additions & 10 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from keras.applications import EfficientNetV2B1
from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch, show_batch_shape
from utilities.discord_callback import DiscordCallback
from keras.optimizers import Adam
from keras.optimizers import AdamW
from keras.regularizers import l1_l2
from keras.callbacks import EarlyStopping, ModelCheckpoint
import os
import random

import tensorflow as tf
# Ignore warnings
import warnings
Expand All @@ -30,6 +32,8 @@
# Set to True to load trained model
load_model = False
load_path = "../models/all_model_variants/efficientnet-old-head-model-variants.h5"
# Set to True to draw a random seed for each run; set to False to use the fixed seed 123 for reproducibility
random_seed = True
# Config
base_path = get_base_path()
path_addon = get_data_path_addon(model_type)
Expand All @@ -38,6 +42,7 @@
"batch_size": 32,
"img_height": img_height,
"img_width": img_width,
"seed": random.randint(0, 1000) if random_seed else 123
}

# Load dataset and classes
Expand All @@ -52,11 +57,6 @@
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Normalize the data
normalization_layer = layers.Rescaling(1. / 255)
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))

# Create data augmentation layer and show augmented batch
data_augmentation = create_augmentation_layer(img_height, img_width)
show_augmented_batch(train_ds, data_augmentation)
Expand Down Expand Up @@ -87,11 +87,11 @@
]) if not load_model else keras.models.load_model(load_path)

# Define optimizer
optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
optimizer = AdamW(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, use_ema=True)

# Define learning rate scheduler
initial_learning_rate = 0.001
lr_decay_steps = 1000
lr_decay_steps = 10
lr_decay_rate = 0.96
lr_scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate,
Expand All @@ -101,7 +101,7 @@

# Compile model
model.compile(optimizer=optimizer,
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['accuracy'])
model.summary()

Expand Down Expand Up @@ -134,4 +134,3 @@
# Save model
model.save(f"{save_path}{name}.h5")

# TODO: Different data augmentation (vertical, ..), Augmentation before training
26 changes: 7 additions & 19 deletions utilities/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import platform


def load_dataset(path: str, batch_size: int, img_height: int, img_width: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]:
def load_dataset(path: str, batch_size: int, img_height: int, img_width: int, seed: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]:
"""
:param path: Path to the Dataset folder
:param batch_size: Integer which defines how many Images are in one Batch
Expand All @@ -20,27 +20,19 @@ def load_dataset(path: str, batch_size: int, img_height: int, img_width: int) ->
:param seed: Integer seed passed to both the training and the validation split
:return: Tuple of train, val Dataset and Class names
"""
data_dir = pathlib.Path(path)
# if "more_classes" in path:
# image_count = len(list(data_dir.glob('*/*/*.jpg')))
# else:
# image_count = len(list(data_dir.glob('*/*/*/*.jpg')))

# print("Image count:", image_count)
# cars = list(data_dir.glob('*/*/*/*.jpg'))
# PIL.Image.open(str(cars[0]))
train_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size)

Expand Down Expand Up @@ -99,12 +91,6 @@ def load_image_subset(path: str, batch_size: int, img_height: int, img_width: in
:return: Subset of Dataset
"""
data_dir = pathlib.Path(path)
# if "more_classes" in path:
# image_count = len(list(data_dir.glob('*/*/*.jpg')))
# else:
# image_count = len(list(data_dir.glob('*/*/*/*.jpg')))

# print("Image count:", image_count)

data = tf.keras.utils.image_dataset_from_directory(
data_dir,
Expand Down Expand Up @@ -168,12 +154,14 @@ def create_augmentation_layer(img_height: int, img_width: int) -> keras.Sequenti
"""
return keras.Sequential(
[
layers.RandomFlip("horizontal",
layers.RandomFlip("vertical",
input_shape=(img_height,
img_width,
3)),
layers.RandomRotation(0.1),
layers.RandomRotation(0.2),
layers.RandomZoom(0.1),
layers.RandomContrast(0.1),
layers.GaussianNoise(0.1)
]
)

Expand Down
Loading