Adds video GAN framework
daniel-j-h committed Oct 26, 2019
1 parent 6f426a6 commit ee270bb
Showing 13 changed files with 670 additions and 55 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -1,6 +1,8 @@
__pycache__
*.py[cod]

assets

*.pth
*.pb
*.pkl
2 changes: 1 addition & 1 deletion Dockerfile.cpu
@@ -11,7 +11,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY requirements.txt .

RUN python3 -m venv /opt/venv && \
python3 -m pip install pip==19.2.3 pip-tools==4.0.0
python3 -m pip install pip==19.2.3 pip-tools==4.0.0 setuptools==41.4.0

RUN echo "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl \
--hash=sha256:ce648bb0c6b86dd99a8b5598ae6362a066cca8de69ad089cd206ace3bdec0a5f \
2 changes: 1 addition & 1 deletion Dockerfile.gpu
@@ -11,7 +11,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY requirements.txt .

RUN python3 -m venv /opt/venv && \
python3 -m pip install pip==19.2.3 pip-tools==4.0.0
python3 -m pip install pip==19.2.3 pip-tools==4.0.0 setuptools==41.4.0

RUN echo "https://download.pytorch.org/whl/cu100/torch-1.3.0%2Bcu100-cp36-cp36m-linux_x86_64.whl \
--hash=sha256:2414744c5f9fc25e4ee181019df188b0ea28c7866ce7af13116c4d7e538460b7 \
101 changes: 101 additions & 0 deletions ig65m/attention.py
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn

from einops import rearrange


# Attention blocks; for background, start with the GCNet paper:
# https://arxiv.org/abs/1904.11492


class SelfAttention3d(nn.Module):
def __init__(self, planes):
super().__init__()

# Note: ratios below should be made configurable

self.q = nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False)
self.k = nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False)
self.v = nn.Conv3d(planes, planes // 2, kernel_size=1, bias=False)
self.z = nn.Conv3d(planes // 2, planes, kernel_size=1, bias=False)

self.y = nn.Parameter(torch.tensor(0.))

def forward(self, x):
q = self.q(x)
k = self.k(x)
v = self.v(x)

# Note: pooling below should be made configurable

k = nn.functional.max_pool3d(k, (2, 2, 2))
v = nn.functional.max_pool3d(v, (2, 2, 2))

q = rearrange(q, "n c t h w -> n (t h w) c")
k = rearrange(k, "n c t h w -> n c (t h w)")
v = rearrange(v, "n c t h w -> n c (t h w)")

beta = torch.bmm(q, k)
beta = torch.softmax(beta, dim=-1)
beta = rearrange(beta, "n thw c -> n c thw")

att = torch.bmm(v, beta)
att = rearrange(att, "n c (t h w) -> n c t h w",
t=x.size(2), h=x.size(3), w=x.size(4))

return self.y * self.z(att) + x


class SimpleSelfAttention3d(nn.Module):
def __init__(self, planes):
super().__init__()

self.k = nn.Conv3d(planes, 1, kernel_size=1, bias=False)
self.v = nn.Conv3d(planes, planes, kernel_size=1, bias=False)

self.y = nn.Parameter(torch.tensor(0.))

def forward(self, x):
k = self.k(x)
k = rearrange(k, "n c t h w -> n (t h w) c")
        k = torch.softmax(k, dim=1)  # softmax over the (t h w) positions; the last axis has size 1

xx = rearrange(x, "n c t h w -> n c (t h w)")

ctx = torch.bmm(xx, k)
ctx = rearrange(ctx, "n c () -> n c () () ()")

att = self.v(ctx)

return self.y * att + x


class GlobalContext3d(nn.Module):
def __init__(self, planes):
super().__init__()

self.k = nn.Conv3d(planes, 1, kernel_size=1, bias=False)

# Note: ratios below should be made configurable

self.v = nn.Sequential(
nn.Conv3d(planes, planes // 8, kernel_size=1, bias=False),
nn.LayerNorm((planes // 8, 1, 1, 1)),
nn.ReLU(inplace=True),
nn.Conv3d(planes // 8, planes, kernel_size=1, bias=False))

self.y = nn.Parameter(torch.tensor(0.))

def forward(self, x):
k = self.k(x)
k = rearrange(k, "n c t h w -> n (t h w) c")
        k = torch.softmax(k, dim=1)  # softmax over the (t h w) positions; the last axis has size 1

xx = rearrange(x, "n c t h w -> n c (t h w)")

ctx = torch.bmm(xx, k)
ctx = rearrange(ctx, "n c () -> n c () () ()")

att = self.v(ctx)

return self.y * att + x
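
A quick shape check for the three attention blocks above (a sketch, not part of this commit, assuming ig65m.attention and einops are importable). All three blocks are residual and gated by the learnable scalar y, which is initialized to zero, so at initialization each block acts as the identity:

import torch

from ig65m.attention import SelfAttention3d, SimpleSelfAttention3d, GlobalContext3d

x = torch.rand(2, 64, 4, 16, 16)  # batch, channels, frames, height, width

for block in (SelfAttention3d(64), SimpleSelfAttention3d(64), GlobalContext3d(64)):
    out = block(x)
    assert out.shape == x.shape    # the attention branch broadcasts back onto the input
    assert torch.allclose(out, x)  # y == 0, so the branch contributes nothing yet
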
11 changes: 11 additions & 0 deletions ig65m/cli/__main__.py
@@ -5,6 +5,7 @@
import ig65m.cli.extract
import ig65m.cli.semcode
import ig65m.cli.dreamer
import ig65m.cli.vgan


parser = argparse.ArgumentParser(prog="ig65m")
@@ -50,5 +51,15 @@
dreamer.set_defaults(main=ig65m.cli.dreamer.main)


vgan = subcmd.add_parser("vgan", help="🥑 video generative adversarial network", formatter_class=Formatter)
vgan.add_argument("videos", type=Path, help="directory to read videos from")
vgan.add_argument("--num-epochs", type=int, default=100, help="number of epochs to run through dataset")
vgan.add_argument("--batch-size", type=int, default=1, help="number of clips per batch")
vgan.add_argument("--clip-length", type=int, default=32, help="number of frames per clip")
vgan.add_argument("--z-dimension", type=int, default=128, help="noise dimensionality")
vgan.add_argument("--logs", type=Path, required=True, help="directory to save TensorBoard logs to")
vgan.set_defaults(main=ig65m.cli.vgan.main)


args = parser.parse_args()
args.main(args)
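
With the sub-command wired up above, training can be launched through the package's __main__ entry point, for example python3 -m ig65m.cli vgan ./videos --logs ./runs/vgan (both paths are hypothetical; the remaining flags fall back to the defaults listed above).
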
122 changes: 122 additions & 0 deletions ig65m/cli/vgan.py
@@ -0,0 +1,122 @@
import sys

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torchvision.transforms import Compose

from einops import rearrange
from einops.layers.torch import Rearrange

from ig65m.datasets import VideoDirectoryDataset
from ig65m.transforms import ToTensor, Resize, CenterCrop, Normalize, Denormalize
from ig65m.losses import GeneratorHingeLoss, DiscriminatorHingeLoss
from ig65m.gan import Generator, Discriminator


def main(args):
if torch.cuda.is_available():
print("🐎 Running on GPU(s)", file=sys.stderr)
device = torch.device("cuda")
torch.backends.cudnn.benchmark = True
else:
print("🐌 Running on CPU(s)", file=sys.stderr)
device = torch.device("cpu")

mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

transform = Compose([
ToTensor(),
Rearrange("t h w c -> c t h w"),
Resize(48),
CenterCrop(32),
Normalize(mean=mean, std=std),
])

denormalize = Denormalize(mean=mean, std=std)

dataset = VideoDirectoryDataset(args.videos, clip_length=args.clip_length, transform=transform)
loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=0)

g = Generator(args.z_dimension)
g = g.to(device)
g = nn.DataParallel(g)

d = Discriminator()
d = d.to(device)
d = nn.DataParallel(d)

opt_g = torch.optim.Adam([p for p in g.parameters() if p.requires_grad],
lr=1e-4 * 1, betas=(0, 0.9))

opt_d = torch.optim.Adam([p for p in d.parameters() if p.requires_grad],
lr=1e-4 * 4, betas=(0, 0.9))

crit_g = GeneratorHingeLoss()
crit_d = DiscriminatorHingeLoss()

zfix = torch.randn(1, args.z_dimension, device=device)

step = 0

with SummaryWriter(str(args.logs)) as summary:
for _ in range(args.num_epochs):
for inputs in loader:
# Step D

g.zero_grad()
d.zero_grad()

z = torch.randn(inputs.size(0), args.z_dimension, device=device)

real_data = inputs.to(device)
fake_data = g(z)

real_out = d(real_data)
fake_out = d(fake_data)

loss_d_real, loss_d_fake = crit_d(real_out, fake_out)
loss_d = loss_d_real.mean() + loss_d_fake.mean()
loss_d.backward()

opt_d.step()

# Step G

g.zero_grad()
d.zero_grad()

z = torch.randn(inputs.size(0), args.z_dimension, device=device)

fake_data = g(z)
fake_out = d(fake_data)

loss_g = crit_g(fake_out)
loss_g.backward()

opt_g.step()

# Done

summary.add_scalar("Loss/Discriminator/Real", loss_d_real.item(), step)
summary.add_scalar("Loss/Discriminator/Fake", loss_d_fake.item(), step)
summary.add_scalar("Loss/Generator", loss_g.item(), step)

with torch.no_grad():
real_data = inputs
real_clip = denormalize(real_data[0])
real_images = rearrange(real_clip, "c t h w -> t c h w")

summary.add_images("Images/Real", real_images, step)

fake_data = g(zfix)
fake_clip = denormalize(fake_data[0])
fake_images = rearrange(fake_clip, "c t h w -> t c h w")

summary.add_images("Images/Fake", fake_images, step)

step += 1

print("🥑 Done", file=sys.stderr)
47 changes: 38 additions & 9 deletions ig65m/datasets.py
@@ -1,4 +1,6 @@
import math
import random
import itertools

from torch.utils.data import IterableDataset, get_worker_info

@@ -10,7 +12,7 @@ def __init__(self, video, first, last):
assert first <= last

for i in range(first):
ret, _ = video.read()
ret = video.grab()

if not ret:
raise RuntimeError("seeking to frame at index {} failed".format(i))
@@ -20,7 +22,7 @@ def __next__(self):
self.last = last

def __next__(self):
if self.it >= self.last or not self.video.isOpened():
if self.it >= self.last:
raise StopIteration

ok, frame = self.video.read()
@@ -57,11 +59,11 @@ def __next__(self):


class VideoDataset(IterableDataset):
def __init__(self, path, clip, transform=None):
def __init__(self, path, clip_length, transform=None):
super().__init__()

self.path = path
self.clip = clip
self.clip_length = clip_length
self.transform = transform

video = cv2.VideoCapture(str(path))
@@ -72,7 +74,7 @@ def __init__(self, path, clip, transform=None):
self.last = frames

def __len__(self):
return self.last // self.clip
return self.last // self.clip_length

def __iter__(self):
info = get_worker_info()
@@ -95,14 +97,14 @@ def __iter__(self):
else:
fn = lambda v: v # noqa: E731

return TransformedRange(BatchedRange(rng, self.clip), fn)
return TransformedRange(BatchedRange(rng, self.clip_length), fn)


class WebcamDataset(IterableDataset):
def __init__(self, clip, transform=None):
def __init__(self, clip_length, transform=None):
super().__init__()

self.clip = clip
self.clip_length = clip_length
self.transform = transform
self.video = cv2.VideoCapture(0)

@@ -120,4 +122,31 @@ def __iter__(self):
else:
fn = lambda v: v # noqa: E731

return TransformedRange(BatchedRange(rng, self.clip), fn)
return TransformedRange(BatchedRange(rng, self.clip_length), fn)


class VideoDirectoryDataset(IterableDataset):
def __init__(self, path, clip_length, transform=None):
super().__init__()

self.clip_length = clip_length
self.transform = transform

        self.paths = [p for p in path.iterdir() if p.is_file()]

self.videos = [VideoDataset(p, clip_length, transform) for p in self.paths]
self.total_clips = sum(len(v) for v in self.videos)

def __iter__(self):
info = get_worker_info()

if info is not None:
raise RuntimeError("multiple workers not supported in VideoDirectoryDataset")

random.shuffle(self.videos)

it = itertools.zip_longest(*self.videos)
it = itertools.chain.from_iterable(it)
it = itertools.filterfalse(lambda v: v is None, it)

return it
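
The iterator above shuffles the per-video datasets and then walks them round-robin, filtering out the None padding that zip_longest emits once the shorter videos are exhausted. A small illustration of that itertools pattern on made-up lists (not from the repository):

import itertools

a = ["a1", "a2", "a3"]
b = ["b1"]
c = ["c1", "c2"]

it = itertools.zip_longest(a, b, c)                  # ('a1', 'b1', 'c1'), ('a2', None, 'c2'), ...
it = itertools.chain.from_iterable(it)               # flatten the round-robin tuples
it = itertools.filterfalse(lambda v: v is None, it)  # drop the padding

print(list(it))  # ['a1', 'b1', 'c1', 'a2', 'c2', 'a3']
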