From e2269b4208a192e094e0f1a85f2d3745977609e3 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 22 Nov 2025 16:23:01 +0200 Subject: [PATCH] Add option to retain condition by indexes for each window This allows for example Wan/HunyuanVideo image to video to "work" by using the initial start frame for each window, otherwise windows beyond first will be pure T2V generations. --- comfy/context_windows.py | 16 ++++++++++++---- comfy_extras/nodes_context_windows.py | 11 +++++++---- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 724a50525..b24d1874e 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -56,13 +56,17 @@ class IndexListContextWindow(ContextWindowABC): self.context_length = len(index_list) self.dim = dim - def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor: + def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: dim = self.dim if dim == 0 and full.shape[dim] == 1: return full idx = tuple([slice(None)] * dim + [self.index_list]) - return full[idx].to(device) + window = full[idx] + if retain_index_list: + idx = tuple([slice(None)] * dim + [retain_index_list]) + window[idx] = full[idx] + return window.to(device) def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: if dim is None: @@ -94,7 +98,8 @@ class ContextFuseMethod: ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) class IndexListContextHandler(ContextHandlerABC): - def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop: bool=False, dim:int=0, freenoise: bool=False): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[]): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -104,6 +109,7 @@ class IndexListContextHandler(ContextHandlerABC): self.dim = dim self._step = 0 self.freenoise = freenoise + self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.callbacks = {} @@ -111,6 +117,8 @@ class IndexListContextHandler(ContextHandlerABC): # for now, assume first dim is batch - should have stored on BaseModel in actual implementation if x_in.size(self.dim) > self.context_length: logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + if self.cond_retain_index_list: + logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") return True return False @@ -154,7 +162,7 @@ class IndexListContextHandler(ContextHandlerABC): elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)): - new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list)) elif cond_key == "num_video_frames": # for SVD new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) new_cond_item[cond_key].cond = window.context_length diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 7ad98f36a..3be1f5132 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -27,6 +27,7 @@ class ContextWindowsManualNode(io.ComfyNode): io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -35,7 +36,7 @@ class ContextWindowsManualNode(io.ComfyNode): ) @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -45,7 +46,8 @@ class ContextWindowsManualNode(io.ComfyNode): context_stride=context_stride, closed_loop=closed_loop, dim=dim, - freenoise=freenoise) + freenoise=freenoise, + cond_retain_index_list=cond_retain_index_list) # make memory usage calculation only take into account the context window latents comfy.context_windows.create_prepare_sampling_wrapper(model) if freenoise: # no other use for this wrapper at this time @@ -73,14 +75,15 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), ] return schema @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model: context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise) + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list) class ContextWindowsExtension(ComfyExtension):