Skip to content

Commit

Permalink
fix precision of DPM solver in bfloat16
Browse files Browse the repository at this point in the history
  • Loading branch information
catwell committed Sep 26, 2024
1 parent ed7e2e5 commit 883a212
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/refiners/foundationals/latent_diffusion/solvers/dpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
first_inference_step=first_inference_step,
params=params,
device=device,
dtype=dtype,
dtype=torch.float64, # compute constants precisely
)
self.estimated_data = deque([torch.tensor([])] * 2, maxlen=2)
self.last_step_first_order = last_step_first_order
Expand All @@ -89,6 +89,7 @@ def __init__(
self.sigmas
)
self.timesteps = self._timesteps_from_sigmas(sigmas)
self.to(dtype=dtype)

def rebuild(
self: "DPMSolver",
Expand Down Expand Up @@ -131,12 +132,9 @@ def _rescale_sigmas(self, sigmas: torch.Tensor, sigma_schedule: NoiseSchedule |
case NoiseSchedule.KARRAS:
rho = 7
case None:
if sigmas.dtype == torch.bfloat16:
sigmas = sigmas.to(torch.float32)
return torch.tensor(
np.interp(self.timesteps.cpu(), np.arange(0, len(sigmas)), sigmas.cpu()),
device=self.device,
dtype=self.dtype,
)

linear_schedule = torch.linspace(0, 1, steps=self.num_inference_steps, device=self.device)
Expand Down

0 comments on commit 883a212

Please sign in to comment.