diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 012f1bbd8..c927c2f00 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -314,7 +314,8 @@ class ContextFuseMethod: ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) class IndexListContextHandler(ContextHandlerABC): def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, - closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, + latent_retain_index_list: list[int]=[]): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -326,6 +327,7 @@ class IndexListContextHandler(ContextHandlerABC): self.freenoise = freenoise self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.split_conds_to_windows = split_conds_to_windows + self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else [] self.callbacks = {} @@ -415,6 +417,8 @@ class IndexListContextHandler(ContextHandlerABC): logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") if self.cond_retain_index_list: logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") + if self.latent_retain_index_list: + logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}") return True logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).") return False @@ -609,7 +613,7 @@ class IndexListContextHandler(ContextHandlerABC): window = window_state.prepare_window(window, model) # slice the window for each modality, injecting guide frames where applicable - sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device) + sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device) for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 0e43f2e44..dcae65dcc 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -14,21 +14,22 @@ class ContextWindowsManualNode(io.ComfyNode): description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, ], tooltip="The stride of the context window."), - io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[]) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode): dim=dim, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, - split_conds_to_windows=split_conds_to_windows + split_conds_to_windows=split_conds_to_windows, + latent_retain_index_list=latent_retain_index_list ) # make memory usage calculation only take into account the context window latents comfy.context_windows.create_prepare_sampling_wrapper(model)