This repository contains a multi-task deep learning model for COVID-19 classification, lung cancer detection, and lung segmentation. The model is built on a ResNet-50 backbone and uses convolutional neural networks (CNNs) for the individual tasks. Training these tasks jointly in a multi-task learning setting is intended to improve generalization, learn more efficient feature representations, and make more effective use of limited data.
Requirements:
- Python 3.9
- PyTorch 1.12
- torchvision 0.13
- Other dependencies (specified in environment/environment.yaml)
Setup and usage:
- Clone the repository:
git clone https://github.com/AmiteshBadkul/LCFNN.git
cd LCFNN/environment/
- Create & activate the conda environment:
conda env create -f environment.yaml
conda activate multi-task-learning
- Prepare the datasets:
  - COVID-19 Classification: place the dataset in the classification_dataset directory. Dataset link: COVID19
  - Lung Cancer Detection: place the dataset in the cancer_detection_dataset directory. Dataset link: Lung Cancer Detection
  - Lung Segmentation: place the dataset in the segmentation_dataset directory. Dataset link: Lung Segmentation
- Train the model (hyperparameters can be modified through command-line arguments):
python main.py
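Joint training usually minimizes a weighted sum of the per-task losses, and the task weights are natural candidates for command-line hyperparameters. The README does not state which losses main.py uses, so the sketch below assumes cross-entropy for both classification tasks and binary cross-entropy plus a soft Dice loss for segmentation. The function names and the weights dictionary are hypothetical.

```python
import torch
import torch.nn.functional as F


def dice_loss(seg_logits: torch.Tensor, target_mask: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Soft Dice loss for binary masks; both tensors are shaped (N, 1, H, W)."""
    probs = torch.sigmoid(seg_logits)
    dims = (1, 2, 3)
    intersection = (probs * target_mask).sum(dims)
    denom = probs.sum(dims) + target_mask.sum(dims)
    # Per-image loss is 1 - Dice; average over the batch.
    return (1 - (2 * intersection + eps) / (denom + eps)).mean()


def multitask_loss(outputs: dict, targets: dict, weights: dict):
    """Weighted sum of per-task losses (hypothetical; loss choices are assumptions, not main.py's).

    Tasks missing from `targets` are skipped, since each batch may come from a single task's dataset.
    """
    losses = {}
    if "covid" in targets:
        losses["covid"] = F.cross_entropy(outputs["covid"], targets["covid"])
    if "cancer" in targets:
        losses["cancer"] = F.cross_entropy(outputs["cancer"], targets["cancer"])
    if "segmentation" in targets:
        seg_logits, mask = outputs["segmentation"], targets["segmentation"]
        losses["segmentation"] = (
            F.binary_cross_entropy_with_logits(seg_logits, mask) + dice_loss(seg_logits, mask)
        )
    total = sum(weights[name] * loss for name, loss in losses.items())
    return total, losses
```

Returning the individual losses alongside the total makes it easy to log each task separately, which helps when diagnosing a single task whose loss behaves badly.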
The project has the following structure:
- code/: contains the code for training and evaluating the model.
- results/: contains the results of the trained models, as well as the trained model itself.
- analysis/: contains Jupyter notebooks for analyzing the results.
The current results are baselines; further improvements to the model are expected to increase performance.
The model achieves the following performance on the test set:
- COVID-19 Classification:
  - Accuracy: 81.05%
  - F1 Score: 80.79%
- Lung Cancer Detection:
  - Accuracy: 52.55%
  - F1 Score: 50.42%
- Lung Segmentation:
  - IoU: 0.26
  - Dice Coefficient: NA
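The segmentation metrics are flagged in the to-do list as needing a correct implementation. A minimal sketch for binary masks, assuming the model outputs one logit channel per pixel and masks are shaped (N, 1, H, W); the function name and the 0.5 probability threshold are hypothetical choices:

```python
import torch


def binary_iou_dice(pred_logits: torch.Tensor, target: torch.Tensor,
                    threshold: float = 0.5, eps: float = 1e-7):
    """Return (mean IoU, mean Dice) over a batch of binary masks shaped (N, 1, H, W)."""
    # Binarize predictions; eps makes two empty masks score 1.0 instead of dividing by zero.
    pred = (torch.sigmoid(pred_logits) > threshold).float()
    target = target.float()
    dims = (1, 2, 3)
    intersection = (pred * target).sum(dims)
    pred_area = pred.sum(dims)
    target_area = target.sum(dims)
    union = pred_area + target_area - intersection
    iou = (intersection + eps) / (union + eps)
    dice = (2 * intersection + eps) / (pred_area + target_area + eps)
    # Average the per-image scores over the batch.
    return iou.mean().item(), dice.mean().item()
```

This version averages per-image scores; pooling intersections and unions over the whole test set instead can give noticeably different numbers, so the aggregation rule is worth stating next to reported results. For a single mask pair, Dice = 2·IoU/(1 + IoU), which gives a quick consistency check between the two metrics.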
Notes and to-do list:
- Correct and effective implementation of the IoU and Dice coefficient metrics.
- Weight balancing techniques.
- The loss for the segmentation task oscillates during training, which may indicate that it is stuck in a local minimum.