Add defensive dtype cast before sigma step check

This commit is contained in:
ozbayb 2026-04-13 15:19:25 -06:00
parent 6a53695006
commit 6442392810

View File

@ -507,7 +507,9 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
current_timestep = timestep[0].to(sample_sigmas.dtype)
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
return # substep from multi-step sampler: keep self._step from the last full step