Skip to content

Commit

Permalink
Merge pull request #94 from Flippchen/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
Flippchen authored Feb 11, 2024
2 parents 60f7f61 + 6878af6 commit 2e7955c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
9 changes: 7 additions & 2 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from keras import layers
from keras.models import Sequential
from keras.applications import EfficientNetV2B1
from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch, show_batch_shape
from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch,\
show_batch_shape, compute_class_weights
from utilities.discord_callback import DiscordCallback
from keras.optimizers import AdamW
from keras.regularizers import l1_l2
Expand Down Expand Up @@ -57,6 +58,9 @@
# Cache both datasets and prefetch with an AUTOTUNE-sized buffer; only the training
# pipeline is shuffled (shuffle buffer of 1000 elements)
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Compute class weights to balance the data
# (the result is handed to training through the class_weight argument further down)
class_weights = compute_class_weights(class_names, train_ds, val_ds)

# Create data augmentation layer and show augmented batch
data_augmentation = create_augmentation_layer(img_height, img_width)
show_augmented_batch(train_ds, data_augmentation)
Expand Down Expand Up @@ -126,7 +130,8 @@
train_ds,
validation_data=val_ds,
epochs=epochs,
callbacks=[lr, early_stopping, model_checkpoint, discord_callback]
callbacks=[lr, early_stopping, model_checkpoint, discord_callback],
class_weight=class_weights
)
# Plot and save model score
plot_model_score(history, name, model_type)
Expand Down
42 changes: 42 additions & 0 deletions utilities/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
import logging
import platform
from sklearn.utils.class_weight import compute_class_weight


def load_dataset(path: str, batch_size: int, img_height: int, img_width: int, seed: int) -> tuple[tf.data.Dataset, tf.data.Dataset, list]:
Expand Down Expand Up @@ -184,6 +185,47 @@ def show_augmented_batch(train_ds, data_augmentation) -> None:
plt.show()


def compute_class_weights(class_names: list, dataset_train: tf.data.Dataset, dataset_val: tf.data.Dataset) -> dict:
    """
    Computes balanced class weights from the label distribution of the train dataset

    The weights follow scikit-learn's 'balanced' heuristic
    (n_samples / (n_classes * samples_of_class)) and are derived from the train dataset only.
    The validation dataset is counted solely so that its class distribution can be printed
    next to the training one; it has no influence on the returned weights.

    :param class_names: List of class names; the position of a name is the integer label of that class
    :param dataset_train: Train-Dataset yielding (images, labels) with integer labels
    :param dataset_val: Validation-Dataset yielding (images, labels) with integer labels
    :return: Dictionary mapping each class index (0 .. len(class_names) - 1) to its weight
    :raises ValueError: If a class has no samples in the train dataset
    :raises IndexError: If a label is not a valid index into class_names
    """
    num_classes = len(class_names)

    train_counts = _count_labels_per_class(dataset_train, num_classes)
    val_counts = _count_labels_per_class(dataset_val, num_classes)

    # NOTE: the values printed here are per-class sample counts, not weights.
    # The label texts are kept as they were so existing log output does not change.
    print("Validation Weights:", {name: int(count) for name, count in zip(class_names, val_counts)})
    print("Train Weights:", {name: int(count) for name, count in zip(class_names, train_counts)})

    # A class without training samples has no defined balanced weight. Fail with a message
    # that names the classes instead of relying on scikit-learn's generic error.
    empty_classes = [name for name, count in zip(class_names, train_counts) if count == 0]
    if empty_classes:
        raise ValueError(f"Cannot compute class weights, no training samples for classes: {empty_classes}")

    # compute_class_weight expects one label per sample, so expand the counts back into labels.
    # Classes are the sequential indices 0..n-1, matching the order of class_names.
    class_indices = np.arange(num_classes)
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=class_indices,
        y=np.repeat(class_indices, train_counts)
    )

    # Keys are the numerical class indices
    return {index: weight for index, weight in enumerate(class_weights)}


def _count_labels_per_class(dataset: tf.data.Dataset, num_classes: int) -> np.ndarray:
    """
    Counts how many samples of each class a dataset of (images, labels) contains

    :param dataset: Dataset yielding (images, labels) with integer labels
    :param num_classes: Number of classes, i.e. the length of the returned array
    :return: Integer array where entry i is the number of samples labelled i
    :raises IndexError: If a label is greater than or equal to num_classes
    """
    counts = np.zeros(num_classes, dtype=np.int64)

    # Count whole batches at once instead of unbatching and visiting every sample in Python
    for _, labels in dataset:
        batch_labels = np.asarray(labels.numpy()).reshape(-1)
        batch_counts = np.bincount(batch_labels, minlength=num_classes)
        if batch_counts.size > num_classes:
            raise IndexError(f"Found label {batch_counts.size - 1}, but there are only {num_classes} classes")
        counts += batch_counts

    return counts


def plot_model_score(history, name: str, model_type: str) -> None:
"""
Plots the accuracy and loss of the model
Expand Down

0 comments on commit 2e7955c

Please sign in to comment.