mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-02 13:22:32 +08:00
Create separate latent_retain_index_list to ensure that inplace latent retention doesn't occur on all models when unintended.
This commit is contained in:
parent
6442392810
commit
4e434bccaa
@ -314,7 +314,8 @@ class ContextFuseMethod:
|
|||||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||||
class IndexListContextHandler(ContextHandlerABC):
|
class IndexListContextHandler(ContextHandlerABC):
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||||
|
latent_retain_index_list: list[int]=[]):
|
||||||
self.context_schedule = context_schedule
|
self.context_schedule = context_schedule
|
||||||
self.fuse_method = fuse_method
|
self.fuse_method = fuse_method
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
@ -326,6 +327,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self.freenoise = freenoise
|
self.freenoise = freenoise
|
||||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||||
self.split_conds_to_windows = split_conds_to_windows
|
self.split_conds_to_windows = split_conds_to_windows
|
||||||
|
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks = {}
|
||||||
|
|
||||||
@ -415,6 +417,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
||||||
if self.cond_retain_index_list:
|
if self.cond_retain_index_list:
|
||||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||||
|
if self.latent_retain_index_list:
|
||||||
|
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
|
||||||
return True
|
return True
|
||||||
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
|
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
|
||||||
return False
|
return False
|
||||||
@ -609,7 +613,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
window = window_state.prepare_window(window, model)
|
window = window_state.prepare_window(window, model)
|
||||||
|
|
||||||
# slice the window for each modality, injecting guide frames where applicable
|
# slice the window for each modality, injecting guide frames where applicable
|
||||||
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device)
|
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
|
||||||
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||||
|
|||||||
@ -14,21 +14,22 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
description="Manually set context windows.",
|
description="Manually set context windows.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
|
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
|
||||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
|
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
|
||||||
io.Combo.Input("context_schedule", options=[
|
io.Combo.Input("context_schedule", options=[
|
||||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||||
comfy.context_windows.ContextSchedules.BATCHED,
|
comfy.context_windows.ContextSchedules.BATCHED,
|
||||||
], tooltip="The stride of the context window."),
|
], tooltip="The stride of the context window."),
|
||||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
|
||||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
|
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||||
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[]) -> io.Model:
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||||
@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
dim=dim,
|
dim=dim,
|
||||||
freenoise=freenoise,
|
freenoise=freenoise,
|
||||||
cond_retain_index_list=cond_retain_index_list,
|
cond_retain_index_list=cond_retain_index_list,
|
||||||
split_conds_to_windows=split_conds_to_windows
|
split_conds_to_windows=split_conds_to_windows,
|
||||||
|
latent_retain_index_list=latent_retain_index_list
|
||||||
)
|
)
|
||||||
# make memory usage calculation only take into account the context window latents
|
# make memory usage calculation only take into account the context window latents
|
||||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user