Skip to content

Commit aa2ce41

Browse files
NotNANtoN and anton-l authored
Fix img2img speed with LMS-Discrete Scheduler (#896)
Casting `self.sigmas` into a different dtype (that of `original_samples`) is not advisable. In my img2img pipeline this leads to a long running time in the later `integrate.quad` call — by "long" I mean more than 10x slower. Co-authored-by: Anton Lozhkov <[email protected]>
1 parent 81fa2d6 commit aa2ce41

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,19 +243,18 @@ def add_noise(
243243
timesteps: torch.FloatTensor,
244244
) -> torch.FloatTensor:
245245
# Make sure sigmas and timesteps have the same device and dtype as original_samples
246-
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
246+
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
247247
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
248248
# mps does not support float64
249-
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
249+
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
250250
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
251251
else:
252-
self.timesteps = self.timesteps.to(original_samples.device)
252+
schedule_timesteps = self.timesteps.to(original_samples.device)
253253
timesteps = timesteps.to(original_samples.device)
254254

255-
schedule_timesteps = self.timesteps
256255
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
257256

258-
sigma = self.sigmas[step_indices].flatten()
257+
sigma = sigmas[step_indices].flatten()
259258
while len(sigma.shape) < len(original_samples.shape):
260259
sigma = sigma.unsqueeze(-1)
261260

0 commit comments

Comments
 (0)