Create separate latent_retain_index_list to ensure that inplace latent retention doesn't occur on all models when unintended.

This commit is contained in:
ozbayb 2026-04-24 09:47:49 -06:00
parent 6442392810
commit 4e434bccaa
2 changed files with 14 additions and 8 deletions

View File

@ -314,7 +314,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC): class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
latent_retain_index_list: list[int]=[]):
self.context_schedule = context_schedule self.context_schedule = context_schedule
self.fuse_method = fuse_method self.fuse_method = fuse_method
self.context_length = context_length self.context_length = context_length
@ -326,6 +327,7 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows self.split_conds_to_windows = split_conds_to_windows
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
self.callbacks = {} self.callbacks = {}
@ -415,6 +417,8 @@ class IndexListContextHandler(ContextHandlerABC):
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
if self.cond_retain_index_list: if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
if self.latent_retain_index_list:
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
return True return True
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).") logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
return False return False
@ -609,7 +613,7 @@ class IndexListContextHandler(ContextHandlerABC):
window = window_state.prepare_window(window, model) window = window_state.prepare_window(window, model)
# slice the window for each modality, injecting guide frames where applicable # slice the window for each modality, injecting guide frames where applicable
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device) sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)

View File

@ -14,21 +14,22 @@ class ContextWindowsManualNode(io.ComfyNode):
description="Manually set context windows.", description="Manually set context windows.",
inputs=[ inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
io.Combo.Input("context_schedule", options=[ io.Combo.Input("context_schedule", options=[
comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.STATIC_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
comfy.context_windows.ContextSchedules.BATCHED, comfy.context_windows.ContextSchedules.BATCHED,
], tooltip="The stride of the context window."), ], tooltip="The stride of the context window."),
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
], ],
outputs=[ outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."), io.Model.Output(tooltip="The model with context windows applied during sampling."),
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[]) -> io.Model:
model = model.clone() model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode):
dim=dim, dim=dim,
freenoise=freenoise, freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list, cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows split_conds_to_windows=split_conds_to_windows,
latent_retain_index_list=latent_retain_index_list
) )
# make memory usage calculation only take into account the context window latents # make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model) comfy.context_windows.create_prepare_sampling_wrapper(model)