From 87ea96ef803d768182eadbb64f44ee34a2fee776 Mon Sep 17 00:00:00 2001 From: Flippchen <91947480+Flippchen@users.noreply.github.com> Date: Sun, 11 Feb 2024 17:53:12 +0100 Subject: [PATCH 1/2] Feature: Added Usage of class weight calculation --- training/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/training/train.py b/training/train.py index 6212471..59f9a56 100644 --- a/training/train.py +++ b/training/train.py @@ -4,7 +4,8 @@ from keras import layers from keras.models import Sequential from keras.applications import EfficientNetV2B1 -from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch, show_batch_shape +from utilities.tools import get_data_path_addon, get_base_path, suppress_tf_warnings, load_dataset, show_augmented_batch, create_augmentation_layer, plot_model_score, show_sample_batch,\ + show_batch_shape, compute_class_weights from utilities.discord_callback import DiscordCallback from keras.optimizers import AdamW from keras.regularizers import l1_l2 @@ -57,6 +58,9 @@ train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE) val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE) +# Compute class weights to balance the data +class_weights = compute_class_weights(class_names, train_ds, val_ds) + # Create data augmentation layer and show augmented batch data_augmentation = create_augmentation_layer(img_height, img_width) show_augmented_batch(train_ds, data_augmentation) @@ -126,7 +130,8 @@ train_ds, validation_data=val_ds, epochs=epochs, - callbacks=[lr, early_stopping, model_checkpoint, discord_callback] + callbacks=[lr, early_stopping, model_checkpoint, discord_callback], + class_weight=class_weights ) # Plot and save model score plot_model_score(history, name, model_type) From 6878af608df4790973a31bc3f2b5a4c2f71fb341 Mon Sep 17 00:00:00 2001 From: Flippchen 
def _count_labels(dataset: tf.data.Dataset, num_classes: int) -> np.ndarray:
    """
    Counts how many samples of each class index a dataset contains

    :param dataset: Batched dataset that yields (images, labels) pairs
    :param num_classes: Number of classes, i.e. the length of the returned array
    :return: Integer array where entry i is the number of samples labelled i
    """
    counts = np.zeros(num_classes, dtype=np.int64)
    # unbatch() yields one (image, label) pair per sample; only the label is needed here.
    # NOTE(review): assumes every label is a scalar integer class index - confirm against load_dataset
    for _, label in dataset.unbatch():
        counts[int(label.numpy())] += 1
    return counts


def compute_class_weights(class_names: list, dataset_train: tf.data.Dataset, dataset_val: tf.data.Dataset) -> dict:
    """
    Computes balanced class weights from the label distribution of the training dataset

    The validation dataset is only counted and printed for inspection, it does not influence the returned weights.

    :param class_names: List of class names, label i refers to class_names[i]
    :param dataset_train: Train-Dataset to compute the class weights for
    :param dataset_val: Validation-Dataset whose class distribution is printed
    :return: Dictionary that maps each class index to its weight
    :raises ValueError: If class_names is empty or a class has no sample in the training dataset
    """
    num_classes = len(class_names)

    train_counts = _count_labels(dataset_train, num_classes)
    val_counts = _count_labels(dataset_val, num_classes)

    # These are raw sample counts per class, not weights
    print("Validation class counts:", dict(zip(class_names, val_counts.tolist())))
    print("Train class counts:", dict(zip(class_names, train_counts.tolist())))

    if num_classes == 0:
        raise ValueError("class_names must not be empty")

    # 'balanced' weights are undefined for a class without samples, so fail early and name the class
    empty_classes = [name for name, count in zip(class_names, train_counts) if count == 0]
    if empty_classes:
        raise ValueError(f"No training samples found for classes: {empty_classes}")

    # Rebuild a label vector from the counts: class index i is repeated train_counts[i] times
    class_indices = np.arange(num_classes)
    labels = np.repeat(class_indices, train_counts)

    # Calculate class weights
    # This requires the classes to be sequential numbers starting from 0, which matches the indexing by class_names
    weights = compute_class_weight(
        class_weight='balanced',
        classes=class_indices,
        y=labels
    )

    # Keys are the numerical class indices
    return dict(enumerate(weights))