Allow splitting multiple conds into different windows

This commit is contained in:
kijai 2025-11-23 17:06:00 +02:00
parent dcf721b772
commit a8aae35afd
2 changed files with 25 additions and 7 deletions

View File

@ -51,10 +51,12 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC): class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0): def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
self.index_list = index_list self.index_list = index_list
self.context_length = len(index_list) self.context_length = len(index_list)
self.dim = dim self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None: if dim is None:
@ -75,6 +77,10 @@ class IndexListContextWindow(ContextWindowABC):
full[idx] += to_add full[idx] += to_add
return full return full
def get_region_index(self, num_regions: int) -> int:
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
class IndexListCallbacks: class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@ -99,7 +105,7 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC): class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[]): closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
self.context_schedule = context_schedule self.context_schedule = context_schedule
self.fuse_method = fuse_method self.fuse_method = fuse_method
self.context_length = context_length self.context_length = context_length
@ -110,6 +116,7 @@ class IndexListContextHandler(ContextHandlerABC):
self._step = 0 self._step = 0
self.freenoise = freenoise self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.callbacks = {} self.callbacks = {}
@ -132,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
return None return None
# reuse or resize cond items to match context requirements # reuse or resize cond items to match context requirements
resized_cond = [] resized_cond = []
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in: for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy() resized_actual_cond = actual_cond.copy()
@ -184,7 +196,7 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options) context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
return context_windows return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):

View File

@ -28,6 +28,7 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
], ],
outputs=[ outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."), io.Model.Output(tooltip="The model with context windows applied during sampling."),
@ -36,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model: def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
model = model.clone() model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@ -47,7 +49,9 @@ class ContextWindowsManualNode(io.ComfyNode):
closed_loop=closed_loop, closed_loop=closed_loop,
dim=dim, dim=dim,
freenoise=freenoise, freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list) cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows
)
# make memory usage calculation only take into account the context window latents # make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model) comfy.context_windows.create_prepare_sampling_wrapper(model)
if freenoise: # no other use for this wrapper at this time if freenoise: # no other use for this wrapper at this time
@ -76,14 +80,16 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
] ]
return schema return schema
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model: def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list) return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
class ContextWindowsExtension(ComfyExtension): class ContextWindowsExtension(ComfyExtension):