mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-07 15:52:32 +08:00
feat: Context windows - add causal_window_fix to improve blending of context windows (CORE-100) (#13563)
* Context windows: add causal_window_fix toggle * Fix slice_cond to correctly handle causal anchor index for temporal offsets
This commit is contained in:
parent
1655f8089a
commit
e5369c0eec
@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
|
|||||||
dim = self.dim
|
dim = self.dim
|
||||||
if dim == 0 and full.shape[dim] == 1:
|
if dim == 0 and full.shape[dim] == 1:
|
||||||
return full
|
return full
|
||||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
indices = self.index_list
|
||||||
|
anchor_idx = getattr(self, 'causal_anchor_index', None)
|
||||||
|
if anchor_idx is not None and anchor_idx >= 0:
|
||||||
|
indices = [anchor_idx] + list(indices)
|
||||||
|
idx = tuple([slice(None)] * dim + [indices])
|
||||||
window = full[idx]
|
window = full[idx]
|
||||||
if retain_index_list:
|
if retain_index_list:
|
||||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||||
@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
|||||||
|
|
||||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||||
if temporal_offset > 0:
|
if temporal_offset > 0:
|
||||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||||
|
if anchor_idx is not None and anchor_idx >= 0:
|
||||||
|
# anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
|
||||||
|
skip_count = temporal_offset - 1
|
||||||
|
else:
|
||||||
|
skip_count = temporal_offset
|
||||||
|
|
||||||
|
indices = [i - temporal_offset for i in window.index_list[skip_count:]]
|
||||||
indices = [i for i in indices if 0 <= i]
|
indices = [i for i in indices if 0 <= i]
|
||||||
else:
|
else:
|
||||||
indices = list(window.index_list)
|
indices = list(window.index_list)
|
||||||
@ -150,7 +161,8 @@ class ContextFuseMethod:
|
|||||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||||
class IndexListContextHandler(ContextHandlerABC):
|
class IndexListContextHandler(ContextHandlerABC):
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||||
|
causal_window_fix: bool=True):
|
||||||
self.context_schedule = context_schedule
|
self.context_schedule = context_schedule
|
||||||
self.fuse_method = fuse_method
|
self.fuse_method = fuse_method
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self.freenoise = freenoise
|
self.freenoise = freenoise
|
||||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||||
self.split_conds_to_windows = split_conds_to_windows
|
self.split_conds_to_windows = split_conds_to_windows
|
||||||
|
self.causal_window_fix = causal_window_fix
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks = {}
|
||||||
|
|
||||||
@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
# allow processing to end between context window executions for faster Cancel
|
# allow processing to end between context window executions for faster Cancel
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
|
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||||
|
anchor_applied = False
|
||||||
|
if self.causal_window_fix:
|
||||||
|
anchor_idx = window.index_list[0] - 1
|
||||||
|
if 0 <= anchor_idx < x_in.size(self.dim):
|
||||||
|
window.causal_anchor_index = anchor_idx
|
||||||
|
anchor_applied = True
|
||||||
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||||
|
|
||||||
@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
if device is not None:
|
if device is not None:
|
||||||
for i in range(len(sub_conds_out)):
|
for i in range(len(sub_conds_out)):
|
||||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||||
|
|
||||||
|
# strip causal_window_fix anchor if applied
|
||||||
|
if anchor_applied:
|
||||||
|
for i in range(len(sub_conds_out)):
|
||||||
|
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||||
|
|
||||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@ -29,6 +29,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
|
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||||
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||||
@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
dim=dim,
|
dim=dim,
|
||||||
freenoise=freenoise,
|
freenoise=freenoise,
|
||||||
cond_retain_index_list=cond_retain_index_list,
|
cond_retain_index_list=cond_retain_index_list,
|
||||||
split_conds_to_windows=split_conds_to_windows
|
split_conds_to_windows=split_conds_to_windows,
|
||||||
|
causal_window_fix=causal_window_fix,
|
||||||
)
|
)
|
||||||
# make memory usage calculation only take into account the context window latents
|
# make memory usage calculation only take into account the context window latents
|
||||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user