Allow freenoise to work on other dims, handle 4D batch timestep

Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now.
This commit is contained in:
kijai 2025-12-03 19:16:26 +02:00
parent 446b086ef3
commit cf1e9885f4

View File

@ -192,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]): def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask) matches = torch.nonzero(mask)
if torch.numel(matches) == 0: if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.") raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@ -324,7 +324,7 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise: if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.context_length, handler.context_overlap, extra_args["seed"]) noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
@ -591,24 +591,26 @@ def shift_window_to_end(window: list[int], num_frames: int):
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465 # https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
def apply_freenoise(noise: torch.Tensor, context_length: int, context_overlap: int, seed: int): def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
logging.info("Context windows: Applying FreeNoise") logging.info("Context windows: Applying FreeNoise")
generator = torch.manual_seed(seed) generator = torch.Generator(device='cpu').manual_seed(seed)
latent_video_length = noise.shape[2] latent_video_length = noise.shape[dim]
delta = context_length - context_overlap delta = context_length - context_overlap
for start_idx in range(0, latent_video_length-context_length, delta):
place_idx = start_idx + context_length
if place_idx >= latent_video_length:
break
end_idx = place_idx - 1
if end_idx + delta >= latent_video_length: for start_idx in range(0, latent_video_length - context_length, delta):
final_delta = latent_video_length - place_idx place_idx = start_idx + context_length
list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
list_idx = list_idx[torch.randperm(final_delta, generator=generator)] actual_delta = min(delta, latent_video_length - place_idx)
noise[:, :, place_idx:place_idx + final_delta] = noise[:, :, list_idx] if actual_delta <= 0:
break break
list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
list_idx = list_idx[torch.randperm(delta, generator=generator)] list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
noise[:, :, place_idx:place_idx + delta] = noise[:, :, list_idx]
source_slice = [slice(None)] * noise.ndim
source_slice[dim] = list_idx
target_slice = [slice(None)] * noise.ndim
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
noise[tuple(target_slice)] = noise[tuple(source_slice)]
return noise return noise