Adds video GAN framework #29
Open · wants to merge 1 commit into base: master
2 changes: 2 additions & 0 deletions .dockerignore
@@ -1,6 +1,8 @@
 __pycache__
 *.py[cod]
+
+assets
 
 *.pth
 *.pb
 *.pkl
2 changes: 1 addition & 1 deletion Dockerfile.cpu
@@ -11,7 +11,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 COPY requirements.txt .
 
 RUN python3 -m venv /opt/venv && \
-    python3 -m pip install pip==19.2.3 pip-tools==4.0.0
+    python3 -m pip install pip==19.2.3 pip-tools==4.0.0 setuptools==41.4.0
 
 RUN echo "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl \
     --hash=sha256:ce648bb0c6b86dd99a8b5598ae6362a066cca8de69ad089cd206ace3bdec0a5f \
2 changes: 1 addition & 1 deletion Dockerfile.gpu
@@ -11,7 +11,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
 COPY requirements.txt .
 
 RUN python3 -m venv /opt/venv && \
-    python3 -m pip install pip==19.2.3 pip-tools==4.0.0
+    python3 -m pip install pip==19.2.3 pip-tools==4.0.0 setuptools==41.4.0
 
 RUN echo "https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl \
     --hash=sha256:2414744c5f9fc25e4ee181019df188b0ea28c7866ce7af13116c4d7e538460b7 \
101 changes: 101 additions & 0 deletions ig65m/attention.py
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn

from einops import rearrange


# The three blocks below follow https://arxiv.org/abs/1904.11492 (GCNet),
# which unifies the non-local block, its simplified variant, and the
# squeeze-excitation style global context block.


class SelfAttention3d(nn.Module):
def __init__(self, planes):
super().__init__()

# Note: ratios below should be made configurable

self.q = nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False)
self.k = nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False)
self.v = nn.Conv3d(planes, planes // 2, kernel_size=1, bias=False)
self.z = nn.Conv3d(planes // 2, planes, kernel_size=1, bias=False)

self.y = nn.Parameter(torch.tensor(0.))

def forward(self, x):
q = self.q(x)
k = self.k(x)
v = self.v(x)

# Note: pooling below should be made configurable

k = nn.functional.max_pool3d(k, (2, 2, 2))
v = nn.functional.max_pool3d(v, (2, 2, 2))

q = rearrange(q, "n c t h w -> n (t h w) c")
k = rearrange(k, "n c t h w -> n c (t h w)")
v = rearrange(v, "n c t h w -> n c (t h w)")

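        # beta[n, i, j]: attention of query position i (t*h*w of them) over
        # pooled key position j (t*h*w / 8 after max-pooling)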
beta = torch.bmm(q, k)
beta = torch.softmax(beta, dim=-1)
        beta = rearrange(beta, "n q kv -> n kv q")  # transpose: the bmm below contracts over kv

att = torch.bmm(v, beta)
att = rearrange(att, "n c (t h w) -> n c t h w",
t=x.size(2), h=x.size(3), w=x.size(4))

return self.y * self.z(att) + x


class SimpleSelfAttention3d(nn.Module):
def __init__(self, planes):
super().__init__()

self.k = nn.Conv3d(planes, 1, kernel_size=1, bias=False)
self.v = nn.Conv3d(planes, planes, kernel_size=1, bias=False)

self.y = nn.Parameter(torch.tensor(0.))

def forward(self, x):
k = self.k(x)
k = rearrange(k, "n c t h w -> n (t h w) c")
        # softmax over the t*h*w positions; k has shape (n, t*h*w, 1) here, so
        # dim=-1 would normalize the singleton channel axis and yield all-ones weights
        k = torch.softmax(k, dim=1)

xx = rearrange(x, "n c t h w -> n c (t h w)")

ctx = torch.bmm(xx, k)
ctx = rearrange(ctx, "n c () -> n c () () ()")

att = self.v(ctx)

return self.y * att + x


class GlobalContext3d(nn.Module):
def __init__(self, planes):
super().__init__()

self.k = nn.Conv3d(planes, 1, kernel_size=1, bias=False)

# Note: ratios below should be made configurable

self.v = nn.Sequential(
nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False),
nn.LayerNorm((planes // 8, 1, 1, 1)),
nn.ReLU(inplace=True),
nn.Conv3d(planes // 8, planes, kernel_size=1, bias=False))

self.y = nn.Parameter(torch.tensor(0.))

def forward(self, x):
k = self.k(x)
k = rearrange(k, "n c t h w -> n (t h w) c")
        # softmax over the t*h*w positions, as in SimpleSelfAttention3d above
        k = torch.softmax(k, dim=1)

xx = rearrange(x, "n c t h w -> n c (t h w)")

ctx = torch.bmm(xx, k)
ctx = rearrange(ctx, "n c () -> n c () () ()")

att = self.v(ctx)

return self.y * att + x
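
A quick shape check of the new blocks (not part of the diff): all three are gated residuals whose learnable gate y starts at zero, so they begin as the identity mapping (as in SAGAN), and they preserve the input shape.

import torch
from ig65m.attention import SelfAttention3d, SimpleSelfAttention3d, GlobalContext3d

x = torch.randn(2, 64, 8, 16, 16)  # (n, c, t, h, w); t, h, w even for the max-pooling

for block in (SelfAttention3d(64), SimpleSelfAttention3d(64), GlobalContext3d(64)):
    assert block(x).shape == x.shape  # gated residual: output shape equals input shape
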
13 changes: 13 additions & 0 deletions ig65m/cli/__main__.py
@@ -5,6 +5,7 @@
 import ig65m.cli.extract
 import ig65m.cli.semcode
 import ig65m.cli.dreamer
+import ig65m.cli.vgan
 
 
 parser = argparse.ArgumentParser(prog="ig65m")
@@ -50,5 +51,17 @@
 dreamer.set_defaults(main=ig65m.cli.dreamer.main)
 
 
+vgan = subcmd.add_parser("vgan", help="🥑 video generative adversarial network", formatter_class=Formatter)
+vgan.add_argument("videos", type=Path, help="directory to read videos from")
+vgan.add_argument("--checkpoints", type=Path, required=True, help="directory to save checkpoints to")
+vgan.add_argument("--num-epochs", type=int, default=100, help="number of epochs to run through dataset")
+vgan.add_argument("--batch-size", type=int, default=1, help="number of clips per batch")
+vgan.add_argument("--clip-length", type=int, default=32, help="number of frames per clip")
+vgan.add_argument("--z-dimension", type=int, default=128, help="noise dimensionality")
+vgan.add_argument("--save-frequency", type=int, default=100, help="number of steps to checkpoint after")
+vgan.add_argument("--logs", type=Path, required=True, help="directory to save TensorBoard logs to")
+vgan.set_defaults(main=ig65m.cli.vgan.main)
+
+
 args = parser.parse_args()
 args.main(args)
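
With the sub-command registered, a training run could be launched along these lines; the directory paths are placeholders, and python3 -m ig65m.cli is the invocation implied by the package's existing __main__.py:

python3 -m ig65m.cli vgan ./videos --checkpoints ./checkpoints --logs ./runs --clip-length 32 --z-dimension 128
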
161 changes: 161 additions & 0 deletions ig65m/cli/vgan.py
@@ -0,0 +1,161 @@
import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision.transforms import Compose

from einops import rearrange
from einops.layers.torch import Rearrange

from ig65m.datasets import VideoDirectoryDataset
from ig65m.transforms import ToTensor, Resize, CenterCrop, Normalize, Denormalize
from ig65m.losses import GeneratorHingeLoss, DiscriminatorHingeLoss
from ig65m.gan import Generator, Discriminator


def main(args):
if torch.cuda.is_available():
print("🐎 Running on GPU(s)", file=sys.stderr)
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True
else:
print("🐌 Running on CPU(s)", file=sys.stderr)
device = torch.device("cpu")

mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

transform = Compose([
ToTensor(),
Rearrange("t h w c -> c t h w"),
        # Resize(48),
CenterCrop(32),
Normalize(mean=mean, std=std),
])

denormalize = Denormalize(mean=mean, std=std)

dataset = VideoDirectoryDataset(args.videos, clip_length=args.clip_length, transform=transform)
loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0)

args.checkpoints.mkdir(exist_ok=True)

g = Generator(args.z_dimension)
g = g.to(device)
g = nn.DataParallel(g)

d = Discriminator()
d = d.to(device)
d = nn.DataParallel(d)

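    # two time-scale update rule: the discriminator uses a 4x larger learning
    # rate than the generator (cf. TTUR, https://arxiv.org/abs/1706.08500)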
lr_g = 1e-4 * 1
lr_d = 1e-4 * 4

opt_g = torch.optim.Adam([p for p in g.parameters() if p.requires_grad],
lr=lr_g, betas=(0, 0.9))

opt_d = torch.optim.Adam([p for p in d.parameters() if p.requires_grad],
lr=lr_d, betas=(0, 0.9))

crit_g = GeneratorHingeLoss()
crit_d = DiscriminatorHingeLoss()

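    # fixed noise vector, reused at every checkpoint so the logged samples are
    # comparable across steps; clamped non-negative like the per-batch noise below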
zfix = torch.randn(1, args.z_dimension, device=device).clamp_(0)

step = 0

with SummaryWriter(str(args.logs)) as summary:
for _ in range(args.num_epochs):
for inputs in loader:
adjust_learning_rate(opt_g, step, lr_g)
adjust_learning_rate(opt_d, step, lr_d)

# Step D

g.zero_grad()
d.zero_grad()

z = torch.randn(inputs.size(0), args.z_dimension, device=device).clamp_(0)

real_data = inputs.to(device)
fake_data = g(z)

real_out = d(real_data)
fake_out = d(fake_data)

loss_d_real, loss_d_fake = crit_d(real_out, fake_out)
loss_d = loss_d_real.mean() + loss_d_fake.mean()
loss_d.backward()

opt_d.step()

# Step G

g.zero_grad()
d.zero_grad()

z = torch.randn(inputs.size(0), args.z_dimension, device=device).clamp_(0)

fake_data = g(z)
fake_out = d(fake_data)

loss_g = crit_g(fake_out)
loss_g.backward()

opt_g.step()

# Done

summary.add_scalar("Loss/Discriminator/Real", loss_d_real.item(), step)
summary.add_scalar("Loss/Discriminator/Fake", loss_d_fake.item(), step)
summary.add_scalar("Loss/Generator", loss_g.item(), step)

if step % args.save_frequency == 0:
real_data = inputs
real_clip = denormalize(real_data[0])
real_images = rearrange(real_clip, "c t h w -> t c h w")

summary.add_images("Images/Real", real_images, step)

with torch.no_grad():
for m in g.modules():
if isinstance(m, nn.BatchNorm3d):
m.eval()

fake_data = g(zfix)

for m in g.modules():
if isinstance(m, nn.BatchNorm3d):
m.train()

fake_clip = denormalize(fake_data[0])
fake_images = rearrange(fake_clip, "c t h w -> t c h w")

summary.add_images("Images/Fake", fake_images, step)

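                    # g and d are nn.DataParallel wrappers, so the saved state
                    # dicts carry the "module." key prefix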
state = {"step": step,
"g_state_dict": g.state_dict(), "d_state_dict": d.state_dict(),
"g_opt": opt_g.state_dict(), "d_opt": opt_d.state_dict()}

torch.save(state, args.checkpoints / "state-{:010d}.pth".format(step))

step += 1

print("🥑 Done", file=sys.stderr)


# Linear learning rate warmup over the first steps, following
# "Accurate, Large Minibatch SGD": https://arxiv.org/abs/1706.02677
def adjust_learning_rate(optimizer, step, lr):
    warmup = 1000
    base = 0.01 * lr  # warm up from 1% of the target learning rate

def lerp(c, first, last):
return first + c * (last - first)

if step <= warmup:
lr = lerp(step / warmup, base, lr)

for param_group in optimizer.param_groups:
param_group["lr"] = lr