Skip to content

Commit

Permalink
quick fix for interpolate in cfg example
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 30, 2023
1 parent daf2d28 commit d83751e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
14 changes: 12 additions & 2 deletions denoising_diffusion_pytorch/classifier_free_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def sample(self, classes, cond_scale = 6., rescaled_phi = 0.7):
return sample_fn(classes, (batch_size, channels, image_size, image_size), cond_scale, rescaled_phi)

@torch.no_grad()
def interpolate(self, x1, x2, t = None, lam = 0.5):
def interpolate(self, x1, x2, classes, t = None, lam = 0.5):
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)

Expand All @@ -709,8 +709,9 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))

img = (1 - lam) * xt1 + lam * xt2

for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
img, _ = self.p_sample(img, i, classes)

return img

Expand Down Expand Up @@ -795,3 +796,12 @@ def forward(self, img, *args, **kwargs):
)

sampled_images.shape # (8, 3, 128, 128)

# interpolation

interpolate_out = diffusion.interpolate(
training_images[:1],
training_images[:1],
image_classes[:1]
)

2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.8.14'
__version__ = '1.8.15'

0 comments on commit d83751e

Please sign in to comment.