mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 21:00:16 +08:00
Allow freenoise to work on other dims, handle 4D batch timestep
Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now.
This commit is contained in:
parent
446b086ef3
commit
cf1e9885f4
@ -192,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return resized_cond
|
return resized_cond
|
||||||
|
|
||||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||||
matches = torch.nonzero(mask)
|
matches = torch.nonzero(mask)
|
||||||
if torch.numel(matches) == 0:
|
if torch.numel(matches) == 0:
|
||||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||||
@ -324,7 +324,7 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
|
|||||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||||
if not handler.freenoise:
|
if not handler.freenoise:
|
||||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
noise = apply_freenoise(noise, handler.context_length, handler.context_overlap, extra_args["seed"])
|
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||||
|
|
||||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
|
|
||||||
@ -591,24 +591,26 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
|||||||
|
|
||||||
|
|
||||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||||
def apply_freenoise(noise: torch.Tensor, context_length: int, context_overlap: int, seed: int):
|
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||||
logging.info("Context windows: Applying FreeNoise")
|
logging.info("Context windows: Applying FreeNoise")
|
||||||
generator = torch.manual_seed(seed)
|
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||||
latent_video_length = noise.shape[2]
|
latent_video_length = noise.shape[dim]
|
||||||
delta = context_length - context_overlap
|
delta = context_length - context_overlap
|
||||||
for start_idx in range(0, latent_video_length-context_length, delta):
|
|
||||||
place_idx = start_idx + context_length
|
|
||||||
if place_idx >= latent_video_length:
|
|
||||||
break
|
|
||||||
end_idx = place_idx - 1
|
|
||||||
|
|
||||||
if end_idx + delta >= latent_video_length:
|
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||||
final_delta = latent_video_length - place_idx
|
place_idx = start_idx + context_length
|
||||||
list_idx = torch.tensor(list(range(start_idx,start_idx+final_delta)), device=torch.device("cpu"), dtype=torch.long)
|
|
||||||
list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
|
actual_delta = min(delta, latent_video_length - place_idx)
|
||||||
noise[:, :, place_idx:place_idx + final_delta] = noise[:, :, list_idx]
|
if actual_delta <= 0:
|
||||||
break
|
break
|
||||||
list_idx = torch.tensor(list(range(start_idx,start_idx+delta)), device=torch.device("cpu"), dtype=torch.long)
|
|
||||||
list_idx = list_idx[torch.randperm(delta, generator=generator)]
|
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||||
noise[:, :, place_idx:place_idx + delta] = noise[:, :, list_idx]
|
|
||||||
|
source_slice = [slice(None)] * noise.ndim
|
||||||
|
source_slice[dim] = list_idx
|
||||||
|
target_slice = [slice(None)] * noise.ndim
|
||||||
|
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||||
|
|
||||||
|
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||||
|
|
||||||
return noise
|
return noise
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user