WGANGP not working as expected #7754
Answered
by
ncuxomun
ncuxomun
asked this question in
code help: CV
-
Hi, can you please tell me if I am doing something wrong here, especially the manual update for the generator and the critic: class Generator(nn.Module):
def __init__(self, latent_dim=64, img_shape=None):
super().__init__()
self.img_shape = img_shape
self.init_size = 8 #self.img_shape[1] // 4
self.l1 = nn.Sequential(
nn.Linear(latent_dim, 64*self.init_size**2), nn.LeakyReLU(0.2, inplace=True))
self.conv_blocks = nn.Sequential(
nn.BatchNorm2d(64),
nn.Upsample(scale_factor=2),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=0),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=0),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=0),
nn.BatchNorm2d(16),
nn.LeakyReLU(0.2, inplace=True),
nn.Upsample(scale_factor=2),
nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1),
nn.BatchNorm2d(8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(in_channels=8, out_channels=img_shape[0], kernel_size=3, padding=1),
nn.Tanh()
)
def forward(self, z):
out = self.l1(z)
out = out.view(out.shape[0], 64, self.init_size, self.init_size)
img = self.conv_blocks(out)
return img
class Critic(nn.Module):
def __init__(self, img_shape):
super().__init__()
self.disc = nn.Sequential(
nn.Conv2d(in_channels=img_shape[0], out_channels=16, kernel_size=4, stride=2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(16, 32, kernel_size=4, stride=2),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=4, stride=2),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
)
# The height and width of downsampled image
#
ds_size = 2 ** 4
self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size, 1))
def forward(self, img):
out = self.disc(img)
# import pdb; pdb.set_trace()
out = out.view(out.shape[0], -1)
validity = self.adv_layer(out)
return validity
class WGANGP(pl.LightningModule):
def __init__(self, latent_dim=128, lr=0.0002, lambda_pen=10, crit_repeats=5):
super().__init__()
self.save_hyperparameters()
self.latent_dim = latent_dim
self.lr = lr
self.lambda_pen = lambda_pen
self.crit_repeats = crit_repeats
self.b1 = 0.0
self.b2 = 0.9
### initializing networks
img_shape = (1, 100, 100)
self.generator = Generator(self.latent_dim, img_shape)
self.critic = Critic(img_shape)
# application of weight
self.generator.apply(self.weights_init)
self.critic.apply(self.weights_init)
#
self.validation_z = torch.randn(10, self.latent_dim)
self.example_input_array = torch.zeros(10, self.latent_dim)
# Important: This property activates manual optimization.
self.automatic_optimization = False # True - Auto // # False - Manual update
def forward(self, z):
return self.generator(z)
### weight initialization
def weights_init(self, m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)
if isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)
def training_step(self, batch, batch_idx):
imgs = batch
# # sample noise
# z = torch.randn(imgs.shape[0], self.latent_dim)
# z = z.type_as(imgs)
# optimizers, manual access
g_opt, c_opt = self.optimizers()
# update critic
mean_iteration_critic_loss = 0
for _ in range(self.crit_repeats):
c_opt.zero_grad()
# sample noise
z = torch.randn(imgs.shape[0], self.latent_dim).type_as(imgs)
# fake image
fake = self(z)
crit_fake_pred = self.critic(fake.detach())
crit_real_pred = self.critic(imgs)
# eps
epsilon = torch.rand(len(imgs), 1, 1, 1, device=self.device, requires_grad=True)
# gradient penalty
gp = self.gradient_penalty(self.critic, imgs, fake, epsilon)
# critic loss
critic_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + self.lambda_pen * gp
# Keep track of the average critic loss in this batch
mean_iteration_critic_loss += critic_loss.item() / crit_repeats
# Update gradients
self.manual_backward(critic_loss)
# Update optimizer
c_opt.step()
# log critic average loss
self.log('c_loss_mean', mean_iteration_critic_loss, prog_bar=True)
# update generator
g_opt.zero_grad()
# sample new noise
z_new = torch.randn(imgs.shape[0], self.latent_dim).type_as(imgs)
# new fake image
fake_new = self(z_new)
crit_fake_pred = self.critic(fake_new)
# generator loss
gen_loss = -torch.mean(crit_fake_pred)
# Update gradients
self.manual_backward(gen_loss)
# Update optimizer
g_opt.step()
# log generator average loss
self.log('g_loss', gen_loss, prog_bar=True)
def gradient_penalty(self, crit, real, fake, epsilon):
# mix/interpolate images
mixed_images = real * epsilon + fake * (1 - epsilon)
# Calculate the critic's scores on the mixed images
mixed_scores = crit(mixed_images)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=mixed_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
# Flatten the gradients so that each row captures one image
gradient = gradient.view(len(gradient), -1)
# Calculate the magnitude of every row
gradient_norm = gradient.norm(2, dim=1)
# Penalize the mean squared distance of the gradient norms from 1
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
def configure_optimizers(self):
opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.b1, self.b2))
opt_c = torch.optim.Adam(self.critic.parameters(), lr=self.lr, betas=(self.b1, self.b2))
return opt_g, opt_c
def on_epoch_end(self):
z = self.validation_z.to(self.device)
# log sampled images
sample_imgs = self(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
# defining the hyperparameters
n_epochs = 1000
z_dim = 50
batch_size = 64
lr = 0.0002
c_lambda = 10
crit_repeats = 5 SOLVED |
Beta Was this translation helpful? Give feedback.
Answered by
ncuxomun
May 30, 2021
Replies: 1 comment 2 replies
Answer selected by
akihironitta
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Solved.