mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 09:42:29 +08:00
Context windows fixes and features (#10975)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* Apply cond slice fix * Add FreeNoise * Update context_windows.py * Add option to retain condition by indexes for each window This allows for example Wan/HunyuanVideo image to video to "work" by using the initial start frame for each window, otherwise windows beyond first will be pure T2V generations. * Update context_windows.py * Allow splitting multiple conds into different windows * Add handling for audio_embed * whitespace * Allow freenoise to work on other dims, handle 4D batch timestep Refactor Freenoise function. And fix batch handling as timesteps seem to be expanded to batch size now. * Disable experimental options for now So that the Freenoise and bugfixes can be merged first --------- Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com> Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
This commit is contained in:
parent
6fd463aec9
commit
79d17ba233
@ -51,26 +51,36 @@ class ContextHandlerABC(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class IndexListContextWindow(ContextWindowABC):
|
class IndexListContextWindow(ContextWindowABC):
|
||||||
def __init__(self, index_list: list[int], dim: int=0):
|
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||||
self.index_list = index_list
|
self.index_list = index_list
|
||||||
self.context_length = len(index_list)
|
self.context_length = len(index_list)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
self.total_frames = total_frames
|
||||||
|
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||||
|
|
||||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
dim = self.dim
|
dim = self.dim
|
||||||
if dim == 0 and full.shape[dim] == 1:
|
if dim == 0 and full.shape[dim] == 1:
|
||||||
return full
|
return full
|
||||||
idx = [slice(None)] * dim + [self.index_list]
|
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||||
return full[idx].to(device)
|
window = full[idx]
|
||||||
|
if retain_index_list:
|
||||||
|
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||||
|
window[idx] = full[idx]
|
||||||
|
return window.to(device)
|
||||||
|
|
||||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
dim = self.dim
|
dim = self.dim
|
||||||
idx = [slice(None)] * dim + [self.index_list]
|
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||||
full[idx] += to_add
|
full[idx] += to_add
|
||||||
return full
|
return full
|
||||||
|
|
||||||
|
def get_region_index(self, num_regions: int) -> int:
|
||||||
|
region_idx = int(self.center_ratio * num_regions)
|
||||||
|
return min(max(region_idx, 0), num_regions - 1)
|
||||||
|
|
||||||
|
|
||||||
class IndexListCallbacks:
|
class IndexListCallbacks:
|
||||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||||
@ -94,7 +104,8 @@ class ContextFuseMethod:
|
|||||||
|
|
||||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||||
class IndexListContextHandler(ContextHandlerABC):
|
class IndexListContextHandler(ContextHandlerABC):
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||||
|
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||||
self.context_schedule = context_schedule
|
self.context_schedule = context_schedule
|
||||||
self.fuse_method = fuse_method
|
self.fuse_method = fuse_method
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self.closed_loop = closed_loop
|
self.closed_loop = closed_loop
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self._step = 0
|
self._step = 0
|
||||||
|
self.freenoise = freenoise
|
||||||
|
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||||
|
self.split_conds_to_windows = split_conds_to_windows
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks = {}
|
||||||
|
|
||||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||||
if x_in.size(self.dim) > self.context_length:
|
if x_in.size(self.dim) > self.context_length:
|
||||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||||
|
if self.cond_retain_index_list:
|
||||||
|
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return None
|
return None
|
||||||
# reuse or resize cond items to match context requirements
|
# reuse or resize cond items to match context requirements
|
||||||
resized_cond = []
|
resized_cond = []
|
||||||
|
# if multiple conds, split based on primary region
|
||||||
|
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||||
|
region = window.get_region_index(len(cond_in))
|
||||||
|
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||||
|
cond_in = [cond_in[region]]
|
||||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||||
for actual_cond in cond_in:
|
for actual_cond in cond_in:
|
||||||
resized_actual_cond = actual_cond.copy()
|
resized_actual_cond = actual_cond.copy()
|
||||||
@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||||
for cond_key, cond_value in new_cond_item.items():
|
for cond_key, cond_value in new_cond_item.items():
|
||||||
if isinstance(cond_value, torch.Tensor):
|
if isinstance(cond_value, torch.Tensor):
|
||||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||||
|
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||||
|
# Handle audio_embed (temporal dim is 1)
|
||||||
|
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
audio_cond = cond_value.cond
|
||||||
|
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||||
|
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||||
# if has cond that is a Tensor, check if needs to be subset
|
# if has cond that is a Tensor, check if needs to be subset
|
||||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||||
|
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||||
elif cond_key == "num_video_frames": # for SVD
|
elif cond_key == "num_video_frames": # for SVD
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||||
new_cond_item[cond_key].cond = window.context_length
|
new_cond_item[cond_key].cond = window.context_length
|
||||||
@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return resized_cond
|
return resized_cond
|
||||||
|
|
||||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||||
matches = torch.nonzero(mask)
|
matches = torch.nonzero(mask)
|
||||||
if torch.numel(matches) == 0:
|
if torch.numel(matches) == 0:
|
||||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||||
@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||||
return context_windows
|
return context_windows
|
||||||
|
|
||||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
prev_weight = (bias_total / (bias_total + bias))
|
prev_weight = (bias_total / (bias_total + bias))
|
||||||
new_weight = (bias / (bias_total + bias))
|
new_weight = (bias / (bias_total + bias))
|
||||||
# account for dims of tensors
|
# account for dims of tensors
|
||||||
idx_window = [slice(None)] * self.dim + [idx]
|
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||||
pos_window = [slice(None)] * self.dim + [pos]
|
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||||
# apply new values
|
# apply new values
|
||||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||||
biases_final[i][idx] = bias_total + bias
|
biases_final[i][idx] = bias_total + bias
|
||||||
@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||||
|
model_options = extra_args.get("model_options", None)
|
||||||
|
if model_options is None:
|
||||||
|
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||||
|
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||||
|
if handler is None:
|
||||||
|
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||||
|
if not handler.freenoise:
|
||||||
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
|
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||||
|
|
||||||
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||||
|
model.add_wrapper_with_key(
|
||||||
|
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||||
|
"ContextWindows_sampler_sample",
|
||||||
|
_sampler_sample_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||||
total_dims = len(x_in.shape)
|
total_dims = len(x_in.shape)
|
||||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||||
@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
|||||||
for i in range(len(window)):
|
for i in range(len(window)):
|
||||||
# 2) add end_delta to each val to slide windows to end
|
# 2) add end_delta to each val to slide windows to end
|
||||||
window[i] = window[i] + end_delta
|
window[i] = window[i] + end_delta
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||||
|
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||||
|
logging.info("Context windows: Applying FreeNoise")
|
||||||
|
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||||
|
latent_video_length = noise.shape[dim]
|
||||||
|
delta = context_length - context_overlap
|
||||||
|
|
||||||
|
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||||
|
place_idx = start_idx + context_length
|
||||||
|
|
||||||
|
actual_delta = min(delta, latent_video_length - place_idx)
|
||||||
|
if actual_delta <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||||
|
|
||||||
|
source_slice = [slice(None)] * noise.ndim
|
||||||
|
source_slice[dim] = list_idx
|
||||||
|
target_slice = [slice(None)] * noise.ndim
|
||||||
|
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||||
|
|
||||||
|
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||||
|
|
||||||
|
return noise
|
||||||
|
|||||||
@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||||
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
|
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
|
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||||
@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||||
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||||
@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
context_overlap=context_overlap,
|
context_overlap=context_overlap,
|
||||||
context_stride=context_stride,
|
context_stride=context_stride,
|
||||||
closed_loop=closed_loop,
|
closed_loop=closed_loop,
|
||||||
dim=dim)
|
dim=dim,
|
||||||
|
freenoise=freenoise,
|
||||||
|
cond_retain_index_list=cond_retain_index_list,
|
||||||
|
split_conds_to_windows=split_conds_to_windows
|
||||||
|
)
|
||||||
# make memory usage calculation only take into account the context window latents
|
# make memory usage calculation only take into account the context window latents
|
||||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||||
|
if freenoise: # no other use for this wrapper at this time
|
||||||
|
comfy.context_windows.create_sampler_sample_wrapper(model)
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||||
@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
|||||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
|
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
|
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
]
|
]
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||||
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||||
|
|
||||||
|
|
||||||
class ContextWindowsExtension(ComfyExtension):
|
class ContextWindowsExtension(ComfyExtension):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user