mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 12:50:18 +08:00
Add option to retain conditioning at given indexes for each window
This allows, for example, Wan/HunyuanVideo image-to-video to "work" with context windows by reusing the initial start frame for each window; otherwise, windows beyond the first would be pure T2V generations.
This commit is contained in:
parent
ac93376ef8
commit
e2269b4208
@ -56,13 +56,17 @@ class IndexListContextWindow(ContextWindowABC):
|
|||||||
self.context_length = len(index_list)
|
self.context_length = len(index_list)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
||||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
dim = self.dim
|
dim = self.dim
|
||||||
if dim == 0 and full.shape[dim] == 1:
|
if dim == 0 and full.shape[dim] == 1:
|
||||||
return full
|
return full
|
||||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||||
return full[idx].to(device)
|
window = full[idx]
|
||||||
|
if retain_index_list:
|
||||||
|
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||||
|
window[idx] = full[idx]
|
||||||
|
return window.to(device)
|
||||||
|
|
||||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
@ -94,7 +98,8 @@ class ContextFuseMethod:
|
|||||||
|
|
||||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||||
class IndexListContextHandler(ContextHandlerABC):
|
class IndexListContextHandler(ContextHandlerABC):
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop: bool=False, dim:int=0, freenoise: bool=False):
|
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||||
|
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[]):
|
||||||
self.context_schedule = context_schedule
|
self.context_schedule = context_schedule
|
||||||
self.fuse_method = fuse_method
|
self.fuse_method = fuse_method
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
@ -104,6 +109,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self.dim = dim
|
self.dim = dim
|
||||||
self._step = 0
|
self._step = 0
|
||||||
self.freenoise = freenoise
|
self.freenoise = freenoise
|
||||||
|
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks = {}
|
||||||
|
|
||||||
@ -111,6 +117,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||||
if x_in.size(self.dim) > self.context_length:
|
if x_in.size(self.dim) > self.context_length:
|
||||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||||
|
if self.cond_retain_index_list:
|
||||||
|
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -154,7 +162,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||||
elif cond_key == "num_video_frames": # for SVD
|
elif cond_key == "num_video_frames": # for SVD
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||||
new_cond_item[cond_key].cond = window.context_length
|
new_cond_item[cond_key].cond = window.context_length
|
||||||
|
|||||||
@ -27,6 +27,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||||
@ -35,7 +36,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model:
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||||
@ -45,7 +46,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
context_stride=context_stride,
|
context_stride=context_stride,
|
||||||
closed_loop=closed_loop,
|
closed_loop=closed_loop,
|
||||||
dim=dim,
|
dim=dim,
|
||||||
freenoise=freenoise)
|
freenoise=freenoise,
|
||||||
|
cond_retain_index_list=cond_retain_index_list)
|
||||||
# make memory usage calculation only take into account the context window latents
|
# make memory usage calculation only take into account the context window latents
|
||||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||||
if freenoise: # no other use for this wrapper at this time
|
if freenoise: # no other use for this wrapper at this time
|
||||||
@ -73,14 +75,15 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
|||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
]
|
]
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, cond_retain_index_list: list[int]=[]) -> io.Model:
|
||||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise)
|
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list)
|
||||||
|
|
||||||
|
|
||||||
class ContextWindowsExtension(ComfyExtension):
|
class ContextWindowsExtension(ComfyExtension):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user