Adds video GAN framework
daniel-j-h committed Oct 27, 2019
1 parent 6f426a6 commit 6ac0b2b
Showing 13 changed files with 708 additions and 55 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -1,6 +1,8 @@
__pycache__
*.py[cod]

assets

*.pth
*.pb
*.pkl
2 changes: 1 addition & 1 deletion Dockerfile.cpu
@@ -11,7 +11,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY requirements.txt .

RUN python3 -m venv /opt/venv && \
-    python3 -m pip install pip==19.2.3 pip-tools==4.0.0
+    python3 -m pip install pip==19.2.3 pip-tools==4.0.0 setuptools==41.4.0

RUN echo "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl \
--hash=sha256:ce648bb0c6b86dd99a8b5598ae6362a066cca8de69ad089cd206ace3bdec0a5f \
2 changes: 1 addition & 1 deletion Dockerfile.gpu
@@ -11,7 +11,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY requirements.txt .

RUN python3 -m venv /opt/venv && \
-    python3 -m pip install pip==19.2.3 pip-tools==4.0.0
+    python3 -m pip install pip==19.2.3 pip-tools==4.0.0 setuptools==41.4.0

RUN echo "https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl \
--hash=sha256:2414744c5f9fc25e4ee181019df188b0ea28c7866ce7af13116c4d7e538460b7 \
101 changes: 101 additions & 0 deletions ig65m/attention.py
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn

from einops import rearrange


# Attention: start with this paper
# https://arxiv.org/abs/1904.11492


class SelfAttention3d(nn.Module):
    def __init__(self, planes):
        super().__init__()

        # Note: ratios below should be made configurable

        self.q = nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False)
        self.k = nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False)
        self.v = nn.Conv3d(planes, planes // 2, kernel_size=1, bias=False)
        self.z = nn.Conv3d(planes // 2, planes, kernel_size=1, bias=False)

        self.y = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        # Note: pooling below should be made configurable

        k = nn.functional.max_pool3d(k, (2, 2, 2))
        v = nn.functional.max_pool3d(v, (2, 2, 2))

        q = rearrange(q, "n c t h w -> n (t h w) c")
        k = rearrange(k, "n c t h w -> n c (t h w)")
        v = rearrange(v, "n c t h w -> n c (t h w)")

        beta = torch.bmm(q, k)
        beta = torch.softmax(beta, dim=-1)
        beta = rearrange(beta, "n thw c -> n c thw")

        att = torch.bmm(v, beta)
        att = rearrange(att, "n c (t h w) -> n c t h w",
                        t=x.size(2), h=x.size(3), w=x.size(4))

        return self.y * self.z(att) + x


class SimpleSelfAttention3d(nn.Module):
    def __init__(self, planes):
        super().__init__()

        self.k = nn.Conv3d(planes, 1, kernel_size=1, bias=False)
        self.v = nn.Conv3d(planes, planes, kernel_size=1, bias=False)

        self.y = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        k = self.k(x)
        k = rearrange(k, "n c t h w -> n (t h w) c")
        k = torch.softmax(k, dim=1)  # attend over the (t h w) positions; dim=-1 would act on the singleton channel

        xx = rearrange(x, "n c t h w -> n c (t h w)")

        ctx = torch.bmm(xx, k)
        ctx = rearrange(ctx, "n c () -> n c () () ()")

        att = self.v(ctx)

        return self.y * att + x


class GlobalContext3d(nn.Module):
    def __init__(self, planes):
        super().__init__()

        self.k = nn.Conv3d(planes, 1, kernel_size=1, bias=False)

        # Note: ratios below should be made configurable

        self.v = nn.Sequential(
            nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False),
            nn.LayerNorm((planes // 8, 1, 1, 1)),
            nn.ReLU(inplace=True),
            nn.Conv3d(planes // 8, planes, kernel_size=1, bias=False))

        self.y = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        k = self.k(x)
        k = rearrange(k, "n c t h w -> n (t h w) c")
        k = torch.softmax(k, dim=1)  # attend over the (t h w) positions; dim=-1 would act on the singleton channel

        xx = rearrange(x, "n c t h w -> n c (t h w)")

        ctx = torch.bmm(xx, k)
        ctx = rearrange(ctx, "n c () -> n c () () ()")

        att = self.v(ctx)

        return self.y * att + x
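
All three blocks are residual with a learned scalar gate y initialized to zero, so each module starts out as an identity mapping and attention is blended in gradually during training. A minimal shape-check sketch against the modules added above (batch and feature sizes are arbitrary):

import torch

from ig65m.attention import SelfAttention3d, SimpleSelfAttention3d, GlobalContext3d

x = torch.randn(2, 64, 4, 8, 8)  # n c t h w

for block in (SelfAttention3d(64), SimpleSelfAttention3d(64), GlobalContext3d(64)):
    out = block(x)
    assert out.shape == x.shape    # every block preserves the input shape
    assert torch.allclose(out, x)  # gate y starts at zero: identity at init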
13 changes: 13 additions & 0 deletions ig65m/cli/__main__.py
@@ -5,6 +5,7 @@
import ig65m.cli.extract
import ig65m.cli.semcode
import ig65m.cli.dreamer
import ig65m.cli.vgan


parser = argparse.ArgumentParser(prog="ig65m")
@@ -50,5 +51,17 @@
dreamer.set_defaults(main=ig65m.cli.dreamer.main)


vgan = subcmd.add_parser("vgan", help="🥑 video generative adversarial network", formatter_class=Formatter)
vgan.add_argument("videos", type=Path, help="directory to read videos from")
vgan.add_argument("--checkpoints", type=Path, required=True, help="directory to save checkpoints to")
vgan.add_argument("--num-epochs", type=int, default=100, help="number of epochs to run through dataset")
vgan.add_argument("--batch-size", type=int, default=1, help="number of clips per batch")
vgan.add_argument("--clip-length", type=int, default=32, help="number of frames per clip")
vgan.add_argument("--z-dimension", type=int, default=128, help="noise dimensionality")
vgan.add_argument("--save-frequency", type=int, default=100, help="number of steps to checkpoint after")
vgan.add_argument("--logs", type=Path, required=True, help="directory to save TensorBoard logs to")
vgan.set_defaults(main=ig65m.cli.vgan.main)


args = parser.parse_args()
args.main(args)
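
With the subcommand registered, training runs via the module entry point; for example (the paths and sizes below are placeholders):

python3 -m ig65m.cli vgan /data/videos --checkpoints /data/checkpoints --logs /data/logs --batch-size 2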
161 changes: 161 additions & 0 deletions ig65m/cli/vgan.py
@@ -0,0 +1,161 @@
import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision.transforms import Compose

from einops import rearrange
from einops.layers.torch import Rearrange

from ig65m.datasets import VideoDirectoryDataset
from ig65m.transforms import ToTensor, Resize, CenterCrop, Normalize, Denormalize
from ig65m.losses import GeneratorHingeLoss, DiscriminatorHingeLoss
from ig65m.gan import Generator, Discriminator


def main(args):
    if torch.cuda.is_available():
        print("🐎 Running on GPU(s)", file=sys.stderr)
        device = torch.device("cuda")
        torch.backends.cudnn.benchmark = True
    else:
        print("🐌 Running on CPU(s)", file=sys.stderr)
        device = torch.device("cpu")

    mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

    transform = Compose([
        ToTensor(),
        Rearrange("t h w c -> c t h w"),
        # Resize(48),
        CenterCrop(32),
        Normalize(mean=mean, std=std),
    ])

    denormalize = Denormalize(mean=mean, std=std)

    dataset = VideoDirectoryDataset(args.videos, clip_length=args.clip_length, transform=transform)
    loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0)

    args.checkpoints.mkdir(exist_ok=True)

    g = Generator(args.z_dimension)
    g = g.to(device)
    g = nn.DataParallel(g)

    d = Discriminator()
    d = d.to(device)
    d = nn.DataParallel(d)

    lr_g = 1e-4 * 1
    lr_d = 1e-4 * 4

    opt_g = torch.optim.Adam([p for p in g.parameters() if p.requires_grad],
                             lr=lr_g, betas=(0, 0.9))

    opt_d = torch.optim.Adam([p for p in d.parameters() if p.requires_grad],
                             lr=lr_d, betas=(0, 0.9))

    crit_g = GeneratorHingeLoss()
    crit_d = DiscriminatorHingeLoss()

    zfix = torch.randn(1, args.z_dimension, device=device).clamp_(0)

    step = 0

    with SummaryWriter(str(args.logs)) as summary:
        for _ in range(args.num_epochs):
            for inputs in loader:
                adjust_learning_rate(opt_g, step, lr_g)
                adjust_learning_rate(opt_d, step, lr_d)

                # Step D

                g.zero_grad()
                d.zero_grad()

                z = torch.randn(inputs.size(0), args.z_dimension, device=device).clamp_(0)

                real_data = inputs.to(device)
                fake_data = g(z)

                real_out = d(real_data)
                fake_out = d(fake_data)

                loss_d_real, loss_d_fake = crit_d(real_out, fake_out)
                loss_d = loss_d_real.mean() + loss_d_fake.mean()
                loss_d.backward()

                opt_d.step()

                # Step G

                g.zero_grad()
                d.zero_grad()

                z = torch.randn(inputs.size(0), args.z_dimension, device=device).clamp_(0)

                fake_data = g(z)
                fake_out = d(fake_data)

                loss_g = crit_g(fake_out)
                loss_g.backward()

                opt_g.step()

                # Done

                summary.add_scalar("Loss/Discriminator/Real", loss_d_real.mean().item(), step)
                summary.add_scalar("Loss/Discriminator/Fake", loss_d_fake.mean().item(), step)
                summary.add_scalar("Loss/Generator", loss_g.item(), step)

                if step % args.save_frequency == 0:
                    real_data = inputs
                    real_clip = denormalize(real_data[0])
                    real_images = rearrange(real_clip, "c t h w -> t c h w")

                    summary.add_images("Images/Real", real_images, step)

                    with torch.no_grad():
                        for m in g.modules():
                            if isinstance(m, nn.BatchNorm3d):
                                m.eval()

                        fake_data = g(zfix)

                        for m in g.modules():
                            if isinstance(m, nn.BatchNorm3d):
                                m.train()

                        fake_clip = denormalize(fake_data[0])
                        fake_images = rearrange(fake_clip, "c t h w -> t c h w")

                        summary.add_images("Images/Fake", fake_images, step)

                    state = {"step": step,
                             "g_state_dict": g.state_dict(), "d_state_dict": d.state_dict(),
                             "g_opt": opt_g.state_dict(), "d_opt": opt_d.state_dict()}

                    torch.save(state, args.checkpoints / "state-{:010d}.pth".format(step))

                step += 1

    print("🥑 Done", file=sys.stderr)


# Gradual learning rate warmup, following https://arxiv.org/abs/1706.02677
def adjust_learning_rate(optimizer, step, lr):
    warmup = 1000
    base = 0.01 * lr

    def lerp(c, first, last):
        return first + c * (last - first)

    if step <= warmup:
        lr = lerp(step / warmup, base, lr)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr