mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 04:10:15 +08:00
Allow splitting multiple conds into different windows
This commit is contained in:
parent
dcf721b772
commit
a8aae35afd
@ -51,10 +51,12 @@ class ContextHandlerABC(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class IndexListContextWindow(ContextWindowABC):
|
class IndexListContextWindow(ContextWindowABC):
|
||||||
def __init__(self, index_list: list[int], dim: int=0):
|
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||||
self.index_list = index_list
|
self.index_list = index_list
|
||||||
self.context_length = len(index_list)
|
self.context_length = len(index_list)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
self.total_frames = total_frames
|
||||||
|
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||||
|
|
||||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
@ -75,6 +77,10 @@ class IndexListContextWindow(ContextWindowABC):
|
|||||||
full[idx] += to_add
|
full[idx] += to_add
|
||||||
return full
|
return full
|
||||||
|
|
||||||
|
def get_region_index(self, num_regions: int) -> int:
|
||||||
|
region_idx = int(self.center_ratio * num_regions)
|
||||||
|
return min(max(region_idx, 0), num_regions - 1)
|
||||||
|
|
||||||
|
|
||||||
class IndexListCallbacks:
|
class IndexListCallbacks:
|
||||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||||
@ -99,7 +105,7 @@ class ContextFuseMethod:
|
|||||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||||
class IndexListContextHandler(ContextHandlerABC):
|
class IndexListContextHandler(ContextHandlerABC):
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[]):
|
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||||
self.context_schedule = context_schedule
|
self.context_schedule = context_schedule
|
||||||
self.fuse_method = fuse_method
|
self.fuse_method = fuse_method
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
@ -110,6 +116,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self._step = 0
|
self._step = 0
|
||||||
self.freenoise = freenoise
|
self.freenoise = freenoise
|
||||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||||
|
self.split_conds_to_windows = split_conds_to_windows
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks = {}
|
||||||
|
|
||||||
@ -132,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return None
|
return None
|
||||||
# reuse or resize cond items to match context requirements
|
# reuse or resize cond items to match context requirements
|
||||||
resized_cond = []
|
resized_cond = []
|
||||||
|
# if multiple conds, split based on primary region
|
||||||
|
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||||
|
region = window.get_region_index(len(cond_in))
|
||||||
|
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||||
|
cond_in = [cond_in[region]]
|
||||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||||
for actual_cond in cond_in:
|
for actual_cond in cond_in:
|
||||||
resized_actual_cond = actual_cond.copy()
|
resized_actual_cond = actual_cond.copy()
|
||||||
@ -184,7 +196,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||||
return context_windows
|
return context_windows
|
||||||
|
|
||||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
|
|||||||
@ -28,6 +28,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||||
@ -36,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||||
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||||
@ -47,7 +49,9 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
closed_loop=closed_loop,
|
closed_loop=closed_loop,
|
||||||
dim=dim,
|
dim=dim,
|
||||||
freenoise=freenoise,
|
freenoise=freenoise,
|
||||||
cond_retain_index_list=cond_retain_index_list)
|
cond_retain_index_list=cond_retain_index_list,
|
||||||
|
split_conds_to_windows=split_conds_to_windows
|
||||||
|
)
|
||||||
# make memory usage calculation only take into account the context window latents
|
# make memory usage calculation only take into account the context window latents
|
||||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||||
if freenoise: # no other use for this wrapper at this time
|
if freenoise: # no other use for this wrapper at this time
|
||||||
@ -76,14 +80,16 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
|||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
]
|
]
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||||
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list)
|
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||||
|
|
||||||
|
|
||||||
class ContextWindowsExtension(ComfyExtension):
|
class ContextWindowsExtension(ComfyExtension):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user