
Too much RAM usage by ImageClassificationData #1450

Open · ethanwharris opened this issue Sep 5, 2022 · Discussed in #1442 · 4 comments
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)
Milestone: 0.8.x

Comments

ethanwharris (Collaborator) commented Sep 5, 2022

Discussed in #1442

Originally posted by Hravan September 1, 2022
I'm setting up training for this Kaggle competition dataset: https://www.kaggle.com/competitions/plant-pathology-2021-fgvc8
(I'm using only the samples with single labels here to keep the problem simpler.)

The problem is that ImageClassificationData uses too much RAM and the GPU is underutilized. I wrote equivalent code in plain PyTorch for comparison, to confirm that the problem lies somewhere within ImageClassificationData.

Code shared by both training versions:

import pandas as pd
from skimage import io
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


class PlantDataset(Dataset):
    def __init__(self, df, transform=None) -> None:
        super().__init__()
        self.img_paths = df["image"].tolist()
        self.transform = transform
        self.encoder = OneHotEncoder()
        # One-hot encode the string labels into a dense numpy array.
        self.labels = (
            self.encoder.fit_transform(df["label"].values.reshape(-1, 1))
            .todense()
            .A
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = io.imread(self.img_paths[idx])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[idx]
        # return {
        #    "input": img,
        #    "target": torch.tensor(label, dtype=torch.uint8),
        # }
        return img, torch.tensor(label, dtype=torch.float32)


def preprocess_df(csv_path, images_root):
    df = pd.read_csv(csv_path)
    # Keep only single-label samples (multi-label rows contain spaces).
    df = df[~df["labels"].str.contains(" ")]
    df["image"] = images_root + df["image"]
    df = df.rename(columns={"labels": "label"})
    return df


def split_df(df, train_pct):
    df = df.sample(frac=1)
    n_train = int(train_pct * len(df))
    train_df = df.iloc[:n_train].reset_index()
    val_df = df.iloc[n_train:].reset_index()
    return train_df, val_df


def create_dataloader(df):
    train_compose = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((224, 224)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    dataloader = torch.utils.data.DataLoader(
        PlantDataset(df, transform=train_compose),
        batch_size=32,
        num_workers=8,
        prefetch_factor=8,
    )
    return dataloader
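
For scale, here is a rough estimate of the RAM held just by the prefetched batches under this DataLoader configuration (a back-of-envelope sketch; real usage also includes per-worker copies of the dataset object, decoded-image temporaries, and Python overhead):

bytes_per_image = 224 * 224 * 3 * 4    # float32 after ToTensor
bytes_per_batch = bytes_per_image * 32  # batch_size=32
prefetched = bytes_per_batch * 8 * 8    # num_workers=8 * prefetch_factor=8
print(f"~{prefetched / 2**30:.2f} GiB held by prefetched batches")  # ~1.15 GiB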

Training in plain PyTorch:

import argparse
from time import perf_counter

import torchvision
import tqdm


def train(model, data_loader, n_epochs):
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()

    for i in range(n_epochs):
        for images, labels in tqdm.tqdm(data_loader):
            images = images.cuda()
            preds = model(images)
            loss = loss_fn(preds, labels.cuda())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"End of epoch {i}")


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(512, 6)

    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    train_loader = create_dataloader(train_df)
    time0 = perf_counter()
    train(model, train_loader, 2)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()

Training in Lightning Flash:

import argparse
from time import perf_counter

import flash
import pytorch_lightning as pl
import torch
import torchvision
from flash.image import ImageClassificationData


class Resnet18(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.model.fc = torch.nn.Linear(512, 6)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch["input"], batch["target"]
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters())


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = Resnet18()
    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    datamodule = ImageClassificationData.from_data_frame(
        "image",
        "label",
        train_data_frame=train_df,
        batch_size=32,
        transform_kwargs=dict(image_size=(224, 224)),
        num_workers=8,
        persistent_workers=True,
        pin_memory=False,
    )

    time0 = perf_counter()
    trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count())
    trainer.fit(model, datamodule=datamodule)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()

When I increase batch_size to 64 or num_workers to 16 in ImageClassificationData, I start having problems with RAM, which does not happen with the plain PyTorch version. Any ideas what the problem might be? I tried profiling but didn't reach any sensible conclusion, except that I bet the problem is in BaseDataFetcher in DataModule.
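
One way to quantify the difference between the two versions (a diagnostic sketch, not part of the original report, assuming psutil is installed) is to sample the resident set size of the training process and its DataLoader workers in a background thread and compare the peaks:

import threading
import time

import psutil


def total_rss_mb():
    # RSS of the current process plus all children (DataLoader workers).
    proc = psutil.Process()
    rss = proc.memory_info().rss
    for child in proc.children(recursive=True):
        try:
            rss += child.memory_info().rss
        except psutil.NoSuchProcess:
            pass
    return rss / 2**20


class RssSampler:
    """Samples total RSS in a background thread and tracks the peak."""

    def __init__(self, interval=0.5):
        self.peak = 0.0
        self.interval = interval
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        while not self._stop.is_set():
            self.peak = max(self.peak, total_rss_mb())
            time.sleep(self.interval)

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
        self._thread.join()


# Usage: wrap either training run and compare the peaks.
# with RssSampler() as sampler:
#     trainer.fit(model, datamodule=datamodule)
# print(f"Peak RSS: {sampler.peak:.0f} MiB")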

@ethanwharris added the "bug / fix" and "help wanted" labels on Sep 5, 2022
@ethanwharris added this to the 0.8.x milestone on Sep 5, 2022
@Atharva-Phatak

@ethanwharris, I can take a look if this is open. It seems interesting that there is such a bottleneck. Could you give me a bit more detail?

Maybe we can test this on a smaller dataset like CIFAR and see if that's the case.

ethanwharris (Collaborator, Author)

Hey @Atharva-Phatak, thanks for the offer! Please feel free to take a look 😃 I think a great starting point would be to train a model in Flash (on e.g. CIFAR-10, as you suggested) and the equivalent model using just Lightning, and see whether the maximum batch size you can reach differs between the two. If it does, that would confirm we have a leak.
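
A minimal sketch of that comparison (assuming lightning-flash with the image extras and torchvision; passing a plain torchvision dataset to from_datasets is an assumption about the Flash version in use). It trains one epoch on CIFAR-10 through Flash and through plain Lightning at the same batch size, so the peak RSS (e.g. via the RssSampler sketched above) can be compared:

import flash
import pytorch_lightning as pl
import torch
from flash.image import ImageClassificationData, ImageClassifier
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10


def cifar10_train():
    return CIFAR10("data", train=True, download=True, transform=T.ToTensor())


def run_flash(batch_size):
    # CIFAR-10 wrapped in a Flash DataModule, trained with a Flash model.
    datamodule = ImageClassificationData.from_datasets(
        train_dataset=cifar10_train(), batch_size=batch_size, num_workers=8
    )
    model = ImageClassifier(num_classes=10, backbone="resnet18")
    flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()).fit(
        model, datamodule=datamodule
    )


def run_plain(batch_size):
    # Same data through a plain DataLoader and the Resnet18 LightningModule
    # from this issue (its training_step would need `x, y = batch` for tuple
    # batches, and an fc layer with 10 outputs for CIFAR-10).
    loader = DataLoader(cifar10_train(), batch_size=batch_size, num_workers=8)
    model = Resnet18()
    pl.Trainer(max_epochs=1, gpus=torch.cuda.device_count()).fit(model, loader)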

@Atharva-Phatak

@ethanwharris Sorry, I was busy with college and working on a PR for Bolts. I will look at it this week and let's see where we can go from here :)

Borda (Member) commented Dec 5, 2022

@Atharva-Phatak that would be great if you can still have a look at it... 🐰
