diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 8fb7b9642..012f1bbd8 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -507,7 +507,9 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) + sample_sigmas = model_options["transformer_options"]["sample_sigmas"] + current_timestep = timestep[0].to(sample_sigmas.dtype) + mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: return # substep from multi-step sampler: keep self._step from the last full step