Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FiftyOneTorchDataset #5321

Open
wants to merge 32 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5f19373
first commit
Dec 13, 2024
354c660
made it cleaner
Dec 13, 2024
ee8c483
some docs
Dec 16, 2024
7c4f0e9
some docs
Dec 16, 2024
592dec6
caching values for faster dataloaders
Dec 17, 2024
aa95bc6
changed to torch serialization because it is better if not using fork
Dec 18, 2024
4585d93
typo
Dec 18, 2024
6ae1f27
added note about dataset persistence
Dec 18, 2024
963dc62
DDP
Dec 20, 2024
7d6162c
important docs
Dec 23, 2024
c7f3df8
added top level util function
Dec 23, 2024
dd7480b
optimization - when using database backed, read right from dataset ra…
Dec 24, 2024
0a40b31
added note about changing default start method for processes
Dec 24, 2024
2313c0f
refined notes
Dec 26, 2024
967b84b
Basic example, cached field example, mnist training example
Dec 26, 2024
85fd0e4
removed weights
Dec 26, 2024
b7e3e51
added dataloader for cached field notebook and some fixes
Dec 26, 2024
fc61ed1
added no_grad to eval
Dec 27, 2024
d1ffeca
added util for authorizing process communication in case user is on to…
Dec 27, 2024
e08a447
more notes on ddp
Dec 27, 2024
fa8db3d
mnist ddp example
Dec 30, 2024
b7c5230
updated note about serializability
Jan 14, 2025
0a583dd
warnings and more doc cleanup
Jan 14, 2025
011fec9
expanded DDP
Jan 14, 2025
ea49206
replaced prints with logging
Jan 14, 2025
1558ad1
removed redundant if
Jan 15, 2025
6bf3639
renamed cache_fields to cache_field_names
Jan 15, 2025
c03db90
added distributed init function, updated docs accordingly
Jan 15, 2025
0382c80
updated basic example
Jan 15, 2025
a505e87
renamed tutorial for cache_field_names
Jan 15, 2025
8a52fdf
updated cache field tutorial to reflect name change
Jan 15, 2025
4448027
updated ddp example to use distributed_init method and new cache_fiel…
Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
473 changes: 473 additions & 0 deletions docs/source/recipes/torch-dataset-examples/basic_example.ipynb

Large diffs are not rendered by default.

249 changes: 249 additions & 0 deletions docs/source/recipes/torch-dataset-examples/ddp_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
from argparse import ArgumentParser
import os

import fiftyone as fo
from fiftyone.utils.torch import all_gather, local_broadcast_process_authkey

import torch
from tqdm import tqdm
import numpy as np

import utils


def main(local_rank, dataset_name, num_classes, num_epochs, save_dir):
    """Train, select and test a classifier in one DDP trainer process.

    Reads the ``WORLD_SIZE``, ``LOCAL_WORLD_SIZE`` and ``RANK`` environment
    variables and uses the module-level ``DEVICES`` list, which is populated
    in the ``__main__`` section of this script.

    Args:
        local_rank: index of this process on its node. Selects the entry of
            ``DEVICES`` to run on; local rank 0 also prints progress, saves
            checkpoints and runs the final evaluation report.
        dataset_name: name passed to ``fo.load_dataset``.
        num_classes: number of classes; forwarded to
            ``utils.setup_ddp_model`` and used to build the label list for
            the evaluation report.
        num_epochs: number of train/validation epochs to run.
        save_dir: directory that receives ``epoch_<n>.pt`` checkpoints. A
            leading ``~`` is expanded and the directory is created on demand.

    Raises:
        RuntimeError: if no epoch improved on the initial (infinite)
            validation loss, so there is no checkpoint to test.
    """
    torch.distributed.init_process_group()

    world_size = int(os.environ["WORLD_SIZE"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    global_rank = int(os.environ["RANK"])
    save_dir = os.path.expanduser(save_dir)

    # setup local groups
    # Every process must call new_group() for every group, including groups
    # it is not a member of. The ranks are passed explicitly because
    # new_group() without `ranks` builds a group of ALL processes.
    # Like the membership test below, this assumes global ranks are assigned
    # contiguously per node.
    local_group = None
    for node in range(world_size // local_world_size):
        first_rank = node * local_world_size
        node_ranks = list(range(first_rank, first_rank + local_world_size))
        aux = torch.distributed.new_group(ranks=node_ranks)
        torch.distributed.barrier()
        if global_rank // local_world_size == node:
            local_group = aux
    local_broadcast_process_authkey(local_group)

    model = utils.setup_ddp_model(num_classes=num_classes)
    model.to(DEVICES[local_rank])
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[DEVICES[local_rank]]
    )

    loss_function = torch.nn.CrossEntropyLoss(reduction="none")

    dataset = None
    # synchronously load dataset in each trainer: one local rank loads per
    # iteration while the others wait at the barrier
    for r in range(local_world_size):
        if local_rank == r:
            dataset = fo.load_dataset(dataset_name)
        torch.distributed.barrier(local_group)

    dataloaders = utils.create_dataloaders_ddp(
        dataset,
        utils.mnist_get_item,
        local_process_group=local_group,
        num_workers=4,
        batch_size=16,
        persistent_workers=True,
    )
    optimizer = utils.setup_optim(ddp_model)

    best_epoch = None
    best_loss = np.inf
    for epoch in range(num_epochs):
        train_epoch(
            local_rank,
            ddp_model,
            dataloaders["train"],
            loss_function,
            optimizer,
        )
        validation_loss = validation(
            local_rank,
            ddp_model,
            dataloaders["validation"],
            dataset,
            loss_function,
        )

        # average over all trainers
        validation_loss = np.mean(all_gather(validation_loss))

        if validation_loss < best_loss:
            best_loss = validation_loss
            best_epoch = epoch
            if local_rank == 0:
                print(f"New best lost achieved : {best_loss}. Saving model...")
                os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), f"{save_dir}/epoch_{epoch}.pt")

        # keep all trainers in step so nobody runs ahead of the checkpoint
        torch.distributed.barrier()

    if best_epoch is None:
        # The averaged loss is identical on every rank, so all ranks take
        # this branch together instead of loading "epoch_None.pt".
        raise RuntimeError(
            "No epoch produced a finite validation loss; "
            "there is no checkpoint to evaluate"
        )

    model = utils.setup_ddp_model(
        num_classes=num_classes,
        weights_path=f"{save_dir}/epoch_{best_epoch}.pt",
    ).to(DEVICES[local_rank])
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[DEVICES[local_rank]]
    )
    test_loss = validation(
        local_rank,
        ddp_model,
        dataloaders["test"],
        dataset,
        loss_function,
        save_results=True,
    )
    test_loss = np.mean(all_gather(test_loss))
    classes = [
        utils.mnist_index_to_label_string(i) for i in range(num_classes)
    ]
    if local_rank == 0:
        results = dataset.match_tags("test").evaluate_classifications(
            "predictions",
            gt_field="ground_truth",
            eval_key="eval",
            classes=classes,
            k=3,
        )

        print("Final Test Results:")
        print(f"Loss = {test_loss}")
        results.print_report(classes=classes)

    torch.distributed.destroy_process_group(torch.distributed.group.WORLD)


def train_epoch(local_rank, model, dataloader, loss_function, optimizer):
    """Run one optimization pass over ``dataloader``.

    Args:
        local_rank: index of this process on its node. Selects the entry of
            the module-level ``DEVICES`` list that batches are moved to; only
            local rank 0 shows a progress bar.
        model: module to train; called on ``batch["image"]``.
        dataloader: iterable of dict batches with ``"image"`` and ``"label"``
            tensors.
        loss_function: callable returning per-sample losses, which are
            averaged here before the backward pass.
        optimizer: optimizer stepped once per batch.

    Returns:
        The mean of the per-batch losses, or ``nan`` if ``dataloader``
        yielded no batches.
    """
    model.train()

    show_progress = local_rank == 0
    batches = enumerate(dataloader)
    if show_progress:
        batches = tqdm(batches, total=len(dataloader))

    cummulative_loss = 0
    # Counted explicitly so an empty dataloader cannot leave the loop
    # variable unbound at the return statement.
    num_batches = 0
    for batch_num, batch in batches:
        batch["image"] = batch["image"].to(DEVICES[local_rank])
        batch["label"] = batch["label"].to(DEVICES[local_rank])

        prediction = model(batch["image"])
        loss = torch.mean(loss_function(prediction, batch["label"]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        cummulative_loss = cummulative_loss + loss.detach().cpu().numpy()
        num_batches = batch_num + 1
        if show_progress and batch_num % 100 == 0:
            batches.set_description(
                f"Average Train Loss = {cummulative_loss / (batch_num + 1):10f}"
            )

    if num_batches == 0:
        return float("nan")
    return cummulative_loss / num_batches


@torch.no_grad()
def validation(
    local_rank, model, dataloader, dataset, loss_function, save_results=False
):
    """Evaluate ``model`` on ``dataloader`` and optionally store the results.

    Gradients are disabled for the whole call by the decorator.

    Args:
        local_rank: index of this process on its node. Selects the entry of
            the module-level ``DEVICES`` list that batches are moved to; only
            local rank 0 shows a progress bar.
        model: module to evaluate; called on ``batch["image"]``.
        dataloader: iterable of dict batches with ``"image"`` and ``"label"``
            tensors, plus ``"id"`` when ``save_results`` is set.
        dataset: FiftyOne collection the batches came from; only used when
            ``save_results`` is set.
        loss_function: callable returning per-sample losses.
        save_results: if True, write each sample's loss to the ``"loss"``
            field and its prediction to the ``"predictions"`` field.

    Returns:
        The mean of the per-batch mean losses, or ``nan`` if ``dataloader``
        yielded no batches.
    """
    model.eval()

    show_progress = local_rank == 0
    batches = enumerate(dataloader)
    if show_progress:
        batches = tqdm(batches, total=len(dataloader))

    cummulative_loss = 0
    # Counted explicitly so an empty dataloader cannot leave the loop
    # variable unbound at the return statement.
    num_batches = 0
    for batch_num, batch in batches:
        batch["image"] = batch["image"].to(DEVICES[local_rank])
        batch["label"] = batch["label"].to(DEVICES[local_rank])

        prediction = model(batch["image"])
        loss_individual = (
            loss_function(prediction, batch["label"]).detach().cpu().numpy()
        )

        if save_results:
            # set_values() assigns positionally, so the selected view must
            # follow the order of the batch IDs rather than dataset order.
            samples = dataset._dataset.select(batch["id"], ordered=True)
            samples.set_values("loss", loss_individual.tolist())

            fo_predictions = [
                fo.Classification(
                    label=utils.mnist_index_to_label_string(
                        np.argmax(sample_logits)
                    ),
                    logits=sample_logits,
                )
                for sample_logits in prediction.detach().cpu().numpy()
            ]
            samples.set_values("predictions", fo_predictions)
            samples.save()

        cummulative_loss = cummulative_loss + np.mean(loss_individual)
        num_batches = batch_num + 1
        if show_progress and batch_num % 100 == 0:
            batches.set_description(
                f"Average Validation Loss = {cummulative_loss / (batch_num + 1):10f}"
            )

    if num_batches == 0:
        return float("nan")
    return cummulative_loss / num_batches


if __name__ == "__main__":

    # run with
    #   torchrun --nnodes=1 --nproc-per-node=6 \
    #       PATH/TO/YOUR/ddp_train.py -d mnist -n 10 -e 3 \
    #       -s /PATH/TO/SAVE/WEIGHTS --devices 2 3 4 5 6 7

    argparser = ArgumentParser()
    argparser.add_argument(
        "-d", "--dataset", type=str, help="name of fiftyone dataset"
    )
    argparser.add_argument(
        "-n",
        "--num_classes",
        type=int,
        help="number of classes in the dataset",
    )
    argparser.add_argument(
        "-e",
        "--epochs",
        type=int,
        help="number of epochs to train for",
        default=5,
    )
    argparser.add_argument(
        "-s",
        "--save_dir",
        type=str,
        help="directory to save checkpoints to",
        default="~/mnist_weights",
    )
    argparser.add_argument(
        "--devices", default=range(torch.cuda.device_count()), nargs="*"
    )

    args = argparser.parse_args()

    # One device per local trainer process. Checked explicitly rather than
    # with `assert`, which is stripped when Python runs with -O.
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    if local_world_size != len(args.devices):
        argparser.error(
            f"expected {local_world_size} devices (LOCAL_WORLD_SIZE), "
            f"got {len(args.devices)}"
        )

    # Module-level global read by main(), train_epoch() and validation().
    DEVICES = [torch.device(f"cuda:{d}") for d in args.devices]

    local_rank = int(os.environ["LOCAL_RANK"])

    torch.multiprocessing.set_start_method("forkserver")
    torch.multiprocessing.set_forkserver_preload(["torch", "fiftyone"])

    main(
        local_rank,
        args.dataset,
        args.num_classes,
        args.epochs,
        # the shell does not expand "~" in the argparse default
        os.path.expanduser(args.save_dir),
    )
130 changes: 130 additions & 0 deletions docs/source/recipes/torch-dataset-examples/mnist_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import fiftyone as fo

import torch
from tqdm import tqdm
import numpy as np

import utils

# Default compute device. main() rebinds this module-level global to the
# device it is given; train_epoch() and validation() read it.
DEVICE = torch.device("cuda:0")


def main(dataset, num_classes, num_epochs, device, save_dir):
    """Train, select and test a classifier on a single device.

    Rebinds the module-level ``DEVICE`` global, which ``train_epoch`` and
    ``validation`` read.

    Args:
        dataset: collection passed to ``utils.create_dataloaders`` and used
            for the final evaluation (presumably a FiftyOne dataset; it must
            provide ``match_tags``).
        num_classes: number of classes; forwarded to ``utils.setup_model``
            and used to build the label list for the evaluation report.
        num_epochs: number of train/validation epochs to run.
        device: anything accepted by ``torch.device``.
        save_dir: directory that receives ``epoch_<n>.pt`` checkpoints. A
            leading ``~`` is expanded and the directory is created on demand.

    Raises:
        RuntimeError: if no epoch improved on the initial (infinite)
            validation loss, so there is no checkpoint to test.
    """
    global DEVICE

    # Imported locally because this module has no file-level `import os`.
    import os

    DEVICE = torch.device(device)
    save_dir = os.path.expanduser(save_dir)

    model = utils.setup_model(num_classes).to(DEVICE)
    loss_function = torch.nn.CrossEntropyLoss(reduction="none")
    dataloaders = utils.create_dataloaders(
        dataset,
        utils.mnist_get_item,
        num_workers=4,
        batch_size=16,
        persistent_workers=True,
    )
    optimizer = utils.setup_optim(model)

    best_epoch = None
    best_loss = np.inf
    for epoch in range(num_epochs):
        train_epoch(model, dataloaders["train"], loss_function, optimizer)
        validation_loss = validation(
            model, dataloaders["validation"], dataset, loss_function
        )

        if validation_loss < best_loss:
            best_loss = validation_loss
            print(f"New best lost achieved : {best_loss}. Saving model...")
            best_epoch = epoch
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), f"{save_dir}/epoch_{epoch}.pt")

    if best_epoch is None:
        # Avoid trying to load the non-existent "epoch_None.pt".
        raise RuntimeError(
            "No epoch produced a finite validation loss; "
            "there is no checkpoint to evaluate"
        )

    model = utils.setup_model(
        num_classes, f"{save_dir}/epoch_{best_epoch}.pt"
    ).to(DEVICE)
    test_loss = validation(
        model, dataloaders["test"], dataset, loss_function, save_results=True
    )
    classes = [
        utils.mnist_index_to_label_string(i) for i in range(num_classes)
    ]
    results = dataset.match_tags("test").evaluate_classifications(
        "predictions",
        gt_field="ground_truth",
        eval_key="eval",
        classes=classes,
        k=3,
    )

    print("Final Test Results:")
    print(f"Loss = {test_loss}")
    results.print_report(classes=classes)


def train_epoch(model, dataloader, loss_function, optimizer):
    """Run one optimization pass over ``dataloader`` on ``DEVICE``.

    Args:
        model: module to train; called on ``batch["image"]``.
        dataloader: sized iterable of dict batches with ``"image"`` and
            ``"label"`` tensors.
        loss_function: callable returning per-sample losses, which are
            averaged here before the backward pass.
        optimizer: optimizer stepped once per batch.

    Returns:
        The mean of the per-batch losses, or ``nan`` if ``dataloader``
        yielded no batches.
    """
    model.train()

    cummulative_loss = 0
    # Counted explicitly so an empty dataloader cannot leave the loop
    # variable unbound at the return statement.
    num_batches = 0
    progress = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_num, batch in progress:
        batch["image"] = batch["image"].to(DEVICE)
        batch["label"] = batch["label"].to(DEVICE)

        prediction = model(batch["image"])
        loss = torch.mean(loss_function(prediction, batch["label"]))

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        cummulative_loss = cummulative_loss + loss.detach().cpu().numpy()
        num_batches = batch_num + 1
        if batch_num % 100 == 0:
            progress.set_description(
                f"Average Train Loss = {cummulative_loss / (batch_num + 1):10f}"
            )

    if num_batches == 0:
        return float("nan")
    return cummulative_loss / num_batches


@torch.no_grad()
def validation(model, dataloader, dataset, loss_function, save_results=False):
    """Evaluate ``model`` on ``dataloader`` and optionally store the results.

    Gradients are disabled for the whole call by the decorator.

    Args:
        model: module to evaluate; called on ``batch["image"]``.
        dataloader: sized iterable of dict batches with ``"image"`` and
            ``"label"`` tensors, plus ``"id"`` when ``save_results`` is set.
        dataset: FiftyOne collection the batches came from; only used when
            ``save_results`` is set.
        loss_function: callable returning per-sample losses.
        save_results: if True, write each sample's loss to the ``"loss"``
            field and its prediction to the ``"predictions"`` field.

    Returns:
        The mean of the per-batch mean losses, or ``nan`` if ``dataloader``
        yielded no batches.
    """
    model.eval()

    cummulative_loss = 0
    # Counted explicitly so an empty dataloader cannot leave the loop
    # variable unbound at the return statement.
    num_batches = 0
    progress = tqdm(enumerate(dataloader), total=len(dataloader))
    for batch_num, batch in progress:
        batch["image"] = batch["image"].to(DEVICE)
        batch["label"] = batch["label"].to(DEVICE)

        prediction = model(batch["image"])
        loss_individual = (
            loss_function(prediction, batch["label"]).detach().cpu().numpy()
        )

        if save_results:
            # set_values() assigns positionally, so the selected view must
            # follow the order of the batch IDs rather than dataset order.
            samples = dataset._dataset.select(batch["id"], ordered=True)
            samples.set_values("loss", loss_individual.tolist())

            fo_predictions = [
                fo.Classification(
                    label=utils.mnist_index_to_label_string(
                        np.argmax(sample_logits)
                    ),
                    logits=sample_logits,
                )
                for sample_logits in prediction.detach().cpu().numpy()
            ]
            samples.set_values("predictions", fo_predictions)
            samples.save()

        cummulative_loss = cummulative_loss + np.mean(loss_individual)
        num_batches = batch_num + 1
        if batch_num % 100 == 0:
            progress.set_description(
                f"Average Validation Loss = {cummulative_loss / (batch_num + 1):10f}"
            )

    if num_batches == 0:
        return float("nan")
    return cummulative_loss / num_batches


if __name__ == "__main__":
    # No command-line entry point is defined in this file; running it as a
    # script does nothing.
    pass
Loading
Loading