Add FreeNoise

This commit is contained in:
kijai 2025-11-22 15:06:25 +02:00
parent 1cd5b90385
commit 928fd4d13f
2 changed files with 61 additions and 10 deletions

View File

@@ -61,13 +61,13 @@ class IndexListContextWindow(ContextWindowABC):
             dim = self.dim
         if dim == 0 and full.shape[dim] == 1:
             return full
-        idx = [slice(None)] * dim + [self.index_list]
+        idx = tuple([slice(None)] * dim + [self.index_list])
         return full[idx].to(device)
 
     def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
         if dim is None:
             dim = self.dim
-        idx = [slice(None)] * dim + [self.index_list]
+        idx = tuple([slice(None)] * dim + [self.index_list])
         full[idx] += to_add
         return full
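
For context on the tuple() change above: indexing a tensor with a plain Python list that mixes slice objects and an index list is deprecated advanced indexing in PyTorch/NumPy and can be misinterpreted, while a tuple unambiguously selects along the intended dimension. A minimal sketch (shapes and index values below are illustrative, not from this commit):

import torch

full = torch.arange(24).reshape(2, 3, 4)         # e.g. (batch, frames, features)
index_list = [0, 2]                              # hypothetical window of frame indices
dim = 1
idx = tuple([slice(None)] * dim + [index_list])  # equivalent to full[:, [0, 2]]
window = full[idx]                               # shape (2, 2, 4): frames 0 and 2
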
@@ -94,7 +94,7 @@ class ContextFuseMethod:
 ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
 
 class IndexListContextHandler(ContextHandlerABC):
-    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop: bool=False, dim: int=0, freenoise: bool=False):
         self.context_schedule = context_schedule
         self.fuse_method = fuse_method
         self.context_length = context_length
@@ -103,13 +103,14 @@ class IndexListContextHandler(ContextHandlerABC):
         self.closed_loop = closed_loop
         self.dim = dim
         self._step = 0
+        self.freenoise = freenoise
         self.callbacks = {}
 
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
         # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
         if x_in.size(self.dim) > self.context_length:
-            logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
+            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
             return True
         return False
@@ -252,8 +253,8 @@
                 prev_weight = (bias_total / (bias_total + bias))
                 new_weight = (bias / (bias_total + bias))
                 # account for dims of tensors
-                idx_window = [slice(None)] * self.dim + [idx]
-                pos_window = [slice(None)] * self.dim + [pos]
+                idx_window = tuple([slice(None)] * self.dim + [idx])
+                pos_window = tuple([slice(None)] * self.dim + [pos])
                 # apply new values
                 conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
                 biases_final[i][idx] = bias_total + bias
@@ -289,6 +290,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
     )
 
+
+def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
+    model_options = extra_args.get("model_options", None)
+    if model_options is None:
+        raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+    handler: IndexListContextHandler = model_options.get("context_handler", None)
+    if handler is None:
+        raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+    if not handler.freenoise:
+        return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+    noise = apply_freenoise(noise, handler.context_length, handler.context_overlap, extra_args["seed"])
+    return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+
+
+def create_sampler_sample_wrapper(model: ModelPatcher):
+    model.add_wrapper_with_key(
+        comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
+        "ContextWindows_sampler_sample",
+        _sampler_sample_wrapper
+    )
+
+
 def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
     total_dims = len(x_in.shape)
     weights_tensor = torch.Tensor(weights).to(device=device)
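
For orientation, a WrappersMP.SAMPLER_SAMPLE wrapper receives the executor plus the original sampler arguments and must call the executor to continue the chain; that is what lets _sampler_sample_wrapper swap in FreeNoise-shuffled noise before sampling starts. A hypothetical wrapper with the same calling convention (not part of this commit), just to show the pattern:

def _log_sigmas_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
    # Hypothetical example: inspect or modify the arguments, then defer to the executor.
    logging.info(f"Sampling {len(sigmas) - 1} steps, sigma range {float(sigmas[0]):.4f} -> {float(sigmas[-1]):.4f}")
    return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
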
@@ -540,3 +563,26 @@ def shift_window_to_end(window: list[int], num_frames: int):
     for i in range(len(window)):
         # 2) add end_delta to each val to slide windows to end
         window[i] = window[i] + end_delta
+
+
+# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
+def apply_freenoise(noise: torch.Tensor, context_length: int, context_overlap: int, seed: int):
+    logging.info("Context windows: Applying FreeNoise")
+    generator = torch.manual_seed(seed)
+    latent_video_length = noise.shape[2]
+    delta = context_length - context_overlap
+    for start_idx in range(0, latent_video_length - context_length, delta):
+        place_idx = start_idx + context_length
+        if place_idx >= latent_video_length:
+            break
+        end_idx = place_idx - 1
+        if end_idx + delta >= latent_video_length:
+            final_delta = latent_video_length - place_idx
+            list_idx = torch.tensor(list(range(start_idx, start_idx + final_delta)), device=torch.device("cpu"), dtype=torch.long)
+            list_idx = list_idx[torch.randperm(final_delta, generator=generator)]
+            noise[:, :, place_idx:place_idx + final_delta] = noise[:, :, list_idx]
+            break
+        list_idx = torch.tensor(list(range(start_idx, start_idx + delta)), device=torch.device("cpu"), dtype=torch.long)
+        list_idx = list_idx[torch.randperm(delta, generator=generator)]
+        noise[:, :, place_idx:place_idx + delta] = noise[:, :, list_idx]
+    return noise

View File

@@ -26,6 +26,7 @@ class ContextWindowsManualNode(io.ComfyNode):
                 io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
                 io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
                 io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
+                io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
             ],
             outputs=[
                 io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -34,7 +35,7 @@ class ContextWindowsManualNode(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
+    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool) -> io.Model:
         model = model.clone()
         model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
             context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -43,9 +44,12 @@ class ContextWindowsManualNode(io.ComfyNode):
             context_overlap=context_overlap,
             context_stride=context_stride,
             closed_loop=closed_loop,
-            dim=dim)
+            dim=dim,
+            freenoise=freenoise)
         # make memory usage calculation only take into account the context window latents
         comfy.context_windows.create_prepare_sampling_wrapper(model)
+        if freenoise: # no other use for this wrapper at this time
+            comfy.context_windows.create_sampler_sample_wrapper(model)
         return io.NodeOutput(model)
 
 class WanContextWindowsManualNode(ContextWindowsManualNode):
@@ -68,14 +72,15 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
             io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
             io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
             io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
+            io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
         ]
         return schema
 
     @classmethod
-    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
+    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool) -> io.Model:
         context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
         context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
-        return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
+        return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise)
 
 class ContextWindowsExtension(ComfyExtension):
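
A note on the unit conversion in WanContextWindowsManualNode.execute: the ((n - 1) // 4) + 1 arithmetic appears to map frame counts given in pixel-space frames to latent frames (Wan's VAE compresses time 4x while keeping the first frame), so the handler then windows the latent along dim=2, the frame dimension of (batch, channels, frames, height, width) latents. For illustration:

for frames in (16, 17, 81):
    print(frames, "->", max(((frames - 1) // 4) + 1, 1))   # 16 -> 4, 17 -> 5, 81 -> 21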