From cf1e9885f4cf0cdb2fe5d499d4c56a2b25499f34 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 3 Dec 2025 19:16:26 +0200 Subject: [PATCH] Allow freenoise to work on other dims, handle 4D batch timestep Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now. --- comfy/context_windows.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index cecbe51ad..5c412d1c2 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -192,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: raise Exception("No sample_sigmas matched current timestep; something went wrong.") @@ -324,7 +324,7 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - noise = apply_freenoise(noise, handler.context_length, handler.context_overlap, extra_args["seed"]) + noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) @@ -591,24 +591,26 @@ def shift_window_to_end(window: list[int], num_frames: int): # https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465 -def apply_freenoise(noise: torch.Tensor, context_length: int, context_overlap: int, seed: int): +def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int): logging.info("Context windows: Applying FreeNoise") - generator = torch.manual_seed(seed) - latent_video_length = noise.shape[2] + generator = torch.Generator(device='cpu').manual_seed(seed) + latent_video_length = noise.shape[dim] delta = context_length - context_overlap - for start_idx in range(0, latent_video_length-context_length, delta): - place_idx = start_idx + context_length - if place_idx >= latent_video_length: - break - end_idx = place_idx - 1 - if end_idx + delta >= latent_video_length: - final_delta = latent_video_length - place_idx - list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long) - list_idx = list_idx[torch.randperm(final_delta, generator=generator)] - noise[:, :, place_idx:place_idx + final_delta] = noise[:, :, list_idx] + for start_idx in range(0, latent_video_length - context_length, delta): + place_idx = start_idx + context_length + + actual_delta = min(delta, latent_video_length - place_idx) + if actual_delta <= 0: break - list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long) - list_idx = list_idx[torch.randperm(delta, generator=generator)] - noise[:, :, place_idx:place_idx + delta] = noise[:, :, list_idx] + + list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx + + source_slice = [slice(None)] * noise.ndim + source_slice[dim] = list_idx + target_slice = [slice(None)] * noise.ndim + target_slice[dim] = slice(place_idx, place_idx + actual_delta) + + noise[tuple(target_slice)] = noise[tuple(source_slice)] + return noise