mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
Merge 75e8e4b6dc into 7bbf1e8169
This commit is contained in:
commit
065eaad039
@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.context_overlap = context_overlap
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
|
||||
self.guide_frames_indices: list[int] = []
|
||||
self.guide_overlap_info: list[tuple[int, int]] = []
|
||||
self.guide_kf_local_positions: list[int] = []
|
||||
self.guide_downscale_factors: list[int] = []
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
|
||||
if modality_idx == 0:
|
||||
return self
|
||||
return self.modality_windows[modality_idx]
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
def compute_guide_overlap(
    guide_entries: list[dict],
    keyframe_idxs: torch.Tensor,
    temporal_downscale_ratio: int,
    window_index_list: list[int],
) -> tuple[list[int], list[tuple[int, int]], list[int], int]:
    """Compute which concatenated guide frames overlap with a context window.

    Each guide's latent-space start is derived from its first token's pixel-t-start
    in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided
    (rounding up) by the model's temporal_downscale_ratio.

    Args:
        guide_entries: list of guide_attention_entry dicts; each entry is read for
            ``"latent_shape"`` (element 0 = number of guide frames) and
            ``"pre_filter_count"`` (token stride to the next entry in keyframe_idxs)
        keyframe_idxs: per-token pixel coords cond tensor for the modality
        temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
        window_index_list: the window's frame indices into the video portion

    Returns:
        suffix_indices: indices into the guide_frames tensor for frame selection
        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
        total_overlap: total number of overlapping guide frames
    """
    # Map each window frame to its first local position once, so the per-frame
    # lookup below is O(1) instead of a list.index() scan for every guide frame.
    window_positions: dict[int, int] = {}
    for local_pos, frame in enumerate(window_index_list):
        window_positions.setdefault(frame, local_pos)

    suffix_indices: list[int] = []
    overlap_info: list[tuple[int, int]] = []
    kf_local_positions: list[int] = []
    suffix_base = 0   # offset of the current entry inside the concatenated guide frames
    token_offset = 0  # offset of the current entry's first token inside keyframe_idxs

    for entry_idx, entry in enumerate(guide_entries):
        first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
        # ceil-divide the pixel-space start to get the latent-space start frame
        latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
        guide_len = entry["latent_shape"][0]
        entry_overlap = 0

        for local_offset in range(guide_len):
            window_pos = window_positions.get(latent_start + local_offset)
            if window_pos is None:
                continue
            suffix_indices.append(suffix_base + local_offset)
            kf_local_positions.append(window_pos)
            entry_overlap += 1

        if entry_overlap > 0:
            overlap_info.append((entry_idx, entry_overlap))
        suffix_base += guide_len
        token_offset += entry["pre_filter_count"]

    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WindowingState:
|
||||
"""Per-modality context windowing state for each step,
|
||||
built using IndexListContextHandler._build_window_state().
|
||||
For non-multimodal models the lists are length 1
|
||||
"""
|
||||
latents: list[torch.Tensor] # per-modality working latents (guide frames stripped)
|
||||
guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents
|
||||
guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata
|
||||
keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation
|
||||
latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal)
|
||||
dim: int = 0 # primary modality temporal dim for context windowing
|
||||
is_multimodal: bool = False
|
||||
temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio
|
||||
|
||||
def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
|
||||
"""Reformat window for multimodal contexts by deriving per-modality index lists.
|
||||
Non-multimodal contexts return the input window unchanged.
|
||||
"""
|
||||
if not self.is_multimodal:
|
||||
return window
|
||||
|
||||
x = self.latents[0]
|
||||
primary_total = self.latent_shapes[0][self.dim]
|
||||
primary_overlap = window.context_overlap
|
||||
map_shapes = self.latent_shapes
|
||||
if x.size(self.dim) != primary_total:
|
||||
map_shapes = list(self.latent_shapes)
|
||||
video_shape = list(self.latent_shapes[0])
|
||||
video_shape[self.dim] = x.size(self.dim)
|
||||
map_shapes[0] = torch.Size(video_shape)
|
||||
try:
|
||||
per_modality_indices = model.map_context_window_to_modalities(
|
||||
window.index_list, map_shapes, self.dim)
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
|
||||
modality_windows = {}
|
||||
for mod_idx in range(1, len(self.latents)):
|
||||
modality_total_frames = self.latents[mod_idx].shape[self.dim]
|
||||
ratio = modality_total_frames / primary_total if primary_total > 0 else 1
|
||||
modality_overlap = max(round(primary_overlap * ratio), 0)
|
||||
modality_windows[mod_idx] = IndexListContextWindow(
|
||||
per_modality_indices[mod_idx], dim=self.dim,
|
||||
total_frames=modality_total_frames,
|
||||
context_overlap=modality_overlap)
|
||||
return IndexListContextWindow(
|
||||
window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
|
||||
modality_windows=modality_windows, context_overlap=primary_overlap)
|
||||
|
||||
def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
|
||||
"""Slice latents for a context window, injecting guide frames where applicable.
|
||||
For multimodal contexts, uses the modality-specific windows derived in prepare_window().
|
||||
"""
|
||||
sliced = []
|
||||
guide_frame_counts = []
|
||||
for idx in range(len(self.latents)):
|
||||
modality_window = window.get_window_for_modality(idx)
|
||||
retain = retain_index_list if idx == 0 else []
|
||||
s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain)
|
||||
if self.guide_entries[idx] is not None:
|
||||
s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
|
||||
else:
|
||||
ng = 0
|
||||
sliced.append(s)
|
||||
guide_frame_counts.append(ng)
|
||||
return sliced, guide_frame_counts
|
||||
|
||||
def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
|
||||
"""Strip injected guide frames from per-cond, per-modality outputs in place."""
|
||||
for idx in range(len(self.latents)):
|
||||
if guide_frame_counts[idx] > 0:
|
||||
window_len = len(window.get_window_for_modality(idx).index_list)
|
||||
for ci in range(len(out_per_modality)):
|
||||
out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len)
|
||||
|
||||
def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
|
||||
guide_entries = self.guide_entries[modality_idx]
|
||||
guide_frames = self.guide_latents[modality_idx]
|
||||
keyframe_idxs = self.keyframe_idxs[modality_idx]
|
||||
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
|
||||
guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
|
||||
# Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
kf_local_pos = [p + 1 for p in kf_local_pos]
|
||||
window.guide_frames_indices = suffix_idx
|
||||
window.guide_overlap_info = overlap_info
|
||||
window.guide_kf_local_positions = kf_local_pos
|
||||
|
||||
# Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
|
||||
# guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
|
||||
guide_downscale_factors = []
|
||||
if guide_frame_count > 0:
|
||||
full_H = guide_frames.shape[3]
|
||||
for entry_idx, _ in overlap_info:
|
||||
entry_H = guide_entries[entry_idx]["latent_shape"][1]
|
||||
guide_downscale_factors.append(full_H // entry_H)
|
||||
window.guide_downscale_factors = guide_downscale_factors
|
||||
|
||||
if guide_frame_count > 0:
|
||||
idx = tuple([slice(None)] * self.dim + [suffix_idx])
|
||||
return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count
|
||||
return latent_slice, 0
|
||||
|
||||
def patch_latent_shapes(self, sub_conds, new_shapes):
|
||||
if not self.is_multimodal:
|
||||
return
|
||||
|
||||
for cond_list in sub_conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||
causal_window_fix: bool=True):
|
||||
latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
|
||||
self.causal_window_fix = causal_window_fix
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_latent_shapes(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
return model_conds['latent_shapes'].cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_keyframe_idxs(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
kf = model_conds.get('keyframe_idxs')
|
||||
if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
|
||||
return kf.cond
|
||||
return None
|
||||
|
||||
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
    """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.

    If guide frames are present on the primary modality, only the video portion
    (everything before the trailing guide frames along ``self.dim``) is shuffled.
    """
    entries = self._get_guide_entries(conds)
    guide_count = sum(e["latent_shape"][0] for e in entries) if entries else 0

    shapes = self._get_latent_shapes(conds)
    is_packed_multimodal = shapes is not None and len(shapes) > 1

    if not is_packed_multimodal:
        video_count = noise.size(self.dim) - guide_count
        # NOTE(review): the return value is discarded, so this presumably relies on
        # apply_freenoise mutating the narrowed view in place — confirm.
        apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
        return noise

    modalities = comfy.utils.unpack_latents(noise, shapes)
    primary_total = shapes[0][self.dim]
    primary_video_count = modalities[0].size(self.dim) - guide_count
    # primary modality: shuffle only the non-guide portion (result discarded, as above)
    apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)

    # secondary modalities: scale context length/overlap by their frame-count ratio
    for i in range(1, len(modalities)):
        mod_total = shapes[i][self.dim]
        ratio = mod_total / primary_total if primary_total > 0 else 1
        mod_ctx_len = max(round(self.context_length * ratio), 1)
        mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
        modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)

    noise, _ = comfy.utils.pack_latents(modalities)
    return noise
|
||||
|
||||
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
    """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
    shapes = self._get_latent_shapes(conds)
    multimodal = shapes is not None and len(shapes) > 1
    if multimodal:
        unpacked = comfy.utils.unpack_latents(x_in, shapes)
    else:
        unpacked = [x_in]

    modality_count = len(unpacked)
    working_latents = list(unpacked)
    guide_latents = [None] * modality_count
    guide_entries = [None] * modality_count
    keyframe_idxs = [None] * modality_count

    found_entries = self._get_guide_entries(conds)
    found_keyframe_idxs = self._get_keyframe_idxs(conds)

    # Strip guide frames (only from first modality for now)
    guide_count = 0
    if found_entries is not None:
        guide_count = sum(e["latent_shape"][0] for e in found_entries)
    if guide_count > 0:
        primary = unpacked[0]
        # guide frames occupy the tail of the primary latent along self.dim
        latent_count = primary.size(self.dim) - guide_count
        working_latents[0] = primary.narrow(self.dim, 0, latent_count)
        guide_latents[0] = primary.narrow(self.dim, latent_count, guide_count)
        guide_entries[0] = found_entries
        keyframe_idxs[0] = found_keyframe_idxs

    return WindowingState(
        latents=working_latents,
        guide_latents=guide_latents,
        guide_entries=guide_entries,
        keyframe_idxs=keyframe_idxs,
        latent_shapes=shapes,
        dim=self.dim,
        is_multimodal=multimodal,
        temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute
|
||||
total_frame_count = window_state.latents[0].size(self.dim)
|
||||
if total_frame_count > self.context_length:
|
||||
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
if self.latent_retain_index_list:
|
||||
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
|
||||
return True
|
||||
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
|
||||
current_timestep = timestep[0].to(sample_sigmas.dtype)
|
||||
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
return # substep from multi-step sampler: keep self._step from the last full step
|
||||
@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
window_state = self._build_window_state(x_in, conds, model)
|
||||
num_modalities = len(window_state.latents)
|
||||
|
||||
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
total_windows = len(enumerated_context_windows)
|
||||
|
||||
# Initialize per-modality accumulators (length 1 for single-modality)
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
# accumulate results from each context window
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
results = self.evaluate_context_windows(
|
||||
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
|
||||
model_options, window_state=window_state, total_windows=total_windows)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
|
||||
for mod_idx in range(num_modalities):
|
||||
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
|
||||
modality_window = result.window.get_window_for_modality(mod_idx)
|
||||
self.combine_context_window_results(
|
||||
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
|
||||
result.window_idx, total_windows, timestep,
|
||||
accum[mod_idx], counts[mod_idx], biases[mod_idx])
|
||||
|
||||
# fuse accumulated results into final conds
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
result_out = []
|
||||
for ci in range(len(conds)):
|
||||
finalized = []
|
||||
for mod_idx in range(num_modalities):
|
||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||
f = accum[mod_idx][ci]
|
||||
|
||||
# if guide frames were injected, append them to the end of the fused latents for the next step
|
||||
if window_state.guide_latents[mod_idx] is not None:
|
||||
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
|
||||
finalized.append(f)
|
||||
|
||||
# pack modalities together if needed
|
||||
if window_state.is_multimodal and len(finalized) > 1:
|
||||
packed, _ = comfy.utils.pack_latents(finalized)
|
||||
else:
|
||||
packed = finalized[0]
|
||||
|
||||
result_out.append(packed)
|
||||
return result_out
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
|
||||
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, window_state: WindowingState, total_windows: int = None,
|
||||
device=None, first_device=None):
|
||||
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
|
||||
|
||||
For each window:
|
||||
1. Builds windows (for each modality if multimodal)
|
||||
2. Slices window for each modality
|
||||
3. Injects concatenated latent guide frames where present
|
||||
4. Packs together if needed and calls model
|
||||
5. Unpacks and strips any guides from outputs
|
||||
"""
|
||||
x = window_state.latents[0]
|
||||
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||
# prepare the window accounting for multimodal windows
|
||||
window = window_state.prepare_window(window, model)
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
|
||||
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
|
||||
anchor_applied = False
|
||||
if self.causal_window_fix:
|
||||
anchor_idx = window.index_list[0] - 1
|
||||
@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
window.causal_anchor_index = anchor_idx
|
||||
anchor_applied = True
|
||||
|
||||
# slice the window for each modality, injecting guide frames where applicable
|
||||
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
|
||||
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
|
||||
)
|
||||
|
||||
# if multimodal, pack modalities together
|
||||
if window_state.is_multimodal and len(sliced) > 1:
|
||||
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
|
||||
else:
|
||||
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
|
||||
|
||||
# get resized conds for window
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
|
||||
|
||||
# if multimodal, patch latent_shapes in conds for correct unpacking in model
|
||||
window_state.patch_latent_shapes(sub_conds, sub_shapes)
|
||||
|
||||
# call model on window
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
|
||||
# strip causal_window_fix anchor if applied
|
||||
# unpack outputs
|
||||
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
|
||||
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
|
||||
if anchor_applied:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||
for ci in range(len(out_per_modality)):
|
||||
t = out_per_modality[ci][0]
|
||||
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
|
||||
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
# strip injected guide frames
|
||||
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
|
||||
|
||||
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
|
||||
# Scale noise_shape to a single context window so VRAM estimation budgets per-window.
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
|
||||
if is_packed:
|
||||
# TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
|
||||
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
|
||||
pass
|
||||
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, conds, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
|
||||
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: "IndexListContextHandler", sigma: torch.Tensor=None, context_overlap: int=None):
    """Compute fuse weights for one window via the handler's fuse method.

    Args:
        length: number of positions in the window.
        full_length: total number of positions along the windowed dimension.
        idxs: indices covered by the window.
        handler: context handler; supplies ``fuse_method.func`` and the
            default ``context_overlap``.
        sigma: forwarded unchanged to the fuse function.
        context_overlap: overlap to use for this call; when ``None`` the
            handler's own ``context_overlap`` is used.

    Returns:
        Whatever ``handler.fuse_method.func`` returns.
    """
    if context_overlap is None:
        context_overlap = handler.context_overlap
    return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
    """Build per-position weights that ramp linearly across the overlap regions.

    All positions start at weight 1. On every window except the first
    (``min(idxs) > 0``) the leading ``context_overlap`` positions ramp up from
    1e-37 to 1; on every window except the last (``max(idxs) < full_length-1``)
    the trailing ``context_overlap`` positions ramp down from 1 to 1e-37.

    Args:
        length: number of positions in the window.
        full_length: total number of positions along the windowed dimension.
        idxs: indices covered by the window.
        context_overlap: number of positions shared with a neighbouring window.
            Values outside ``[0, length]`` are clamped into that range.
        **kwargs: ignored; accepted so all fuse functions share a call shape.

    Returns:
        A 1-D ``torch.Tensor`` of ``length`` weights.
    """
    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
    # only expected overlap is given different weights
    weights_torch = torch.ones((length))
    # A ramp longer than the window cannot be assigned into it, so clamp.
    overlap = max(0, min(context_overlap, length))
    if overlap == 0:
        # weights_torch[-0:] would address the WHOLE tensor, and assigning an
        # empty ramp to it raises; with no overlap there is nothing to blend.
        return weights_torch
    # blend left-side on all except first window
    if min(idxs) > 0:
        weights_torch[:overlap] = torch.linspace(1e-37, 1, overlap)
    # blend right-side on all except last window
    if max(idxs) < full_length-1:
        weights_torch[-overlap:] = torch.linspace(1, 1e-37, overlap)
    return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
|
||||
@ -1028,7 +1028,7 @@ class LTXVModel(LTXBaseModel):
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
@ -1315,7 +1315,7 @@ class LTXVModel(LTXBaseModel):
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
|
||||
@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.ldm.lightricks.symmetric_patchifier
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
@ -1094,6 +1095,127 @@ class LTXAV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
# Returns latent_image exactly as received; sigma, noise and kwargs are not used here.
return latent_image
|
||||
|
||||
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Translate a primary-modality window into index lists for every modality.

    Args:
        primary_indices: indices of the window along ``dim`` of the first
            (primary) latent.
        latent_shapes: one shape per modality; entry 0 is the primary one.
        dim: dimension of each shape that holds the windowed length.

    Returns:
        A list whose first element is ``primary_indices`` itself, followed by
        one list of indices per additional modality. With fewer than two
        shapes only the primary entry is returned.
    """
    mapped = [primary_indices]
    if len(latent_shapes) < 2:
        return mapped

    primary_total = latent_shapes[0][dim]

    for shape in latent_shapes[1:]:
        modality_total = shape[dim]
        # Each primary index covers a proportional span of modality indices.
        # Spans are concatenated in the order of primary_indices, which keeps
        # wrapped/strided geometry; a dict is used as an ordered set so an
        # index that appears in two spans keeps its first position.
        ordered = {}
        for p_idx in primary_indices:
            lo = min(int(round(p_idx * modality_total / primary_total)), modality_total - 1)
            hi = min(int(round((p_idx + 1) * modality_total / primary_total)), modality_total)
            # every primary index maps to at least one modality index
            hi = max(hi, lo + 1)
            ordered.update(dict.fromkeys(range(lo, hi)))
        mapped.append(list(ordered))

    return mapped
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize one model cond so it matches the given context window.

    Handles four cond keys: ``audio_denoise_mask``, ``denoise_mask``,
    ``keyframe_idxs`` and ``guide_attention_entries``.

    Returns:
        ``cond_value._copy_with(...)`` holding the resized payload, or
        ``None`` when the key is not one of the four above or the
        preconditions for that key are not met.
    """
    holds_tensor = hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)

    # Audio denoise mask — slice using audio modality window
    if cond_key == "audio_denoise_mask":
        modality_windows = getattr(window, 'modality_windows', None)
        audio_window = modality_windows.get(1) if modality_windows else None
        if audio_window is None or not holds_tensor:
            return None
        return cond_value._copy_with(audio_window.get_tensor(cond_value.cond, device, dim=2))

    # Video denoise mask — split into video + guide portions, slice each
    if cond_key == "denoise_mask":
        if not holds_tensor:
            return None
        full_mask = cond_value.cond
        video_len = x_in.size(window.dim)
        guide_len = full_mask.size(window.dim) - video_len
        if guide_len <= 0:
            # mask is no longer than x_in along window.dim: nothing to split
            return None
        video_part = window.get_tensor(full_mask.narrow(window.dim, 0, video_len), device, retain_index_list=retain_index_list)
        guide_indices = window.guide_frames_indices
        if not guide_indices:
            return cond_value._copy_with(video_part)
        guide_part = full_mask.narrow(window.dim, video_len, guide_len)
        selector = (slice(None),) * window.dim + (guide_indices,)
        merged = torch.cat([video_part, guide_part[selector].to(device)], dim=window.dim)
        return cond_value._copy_with(merged)

    # Keyframe indices — regenerate pixel coords for window, select guide positions
    if cond_key == "keyframe_idxs":
        local_positions = window.guide_kf_local_positions
        if not local_positions:
            return cond_value._copy_with(cond_value.cond[:, :, :0, :])  # empty
        height, width = x_in.shape[3], x_in.shape[4]
        tokens_per_frame = height * width
        frame_count = len(window.index_list)
        # account for causal_window_fix anchor in coord space size
        anchor = getattr(window, 'causal_anchor_index', None)
        if anchor is not None and anchor >= 0:
            frame_count += 1
        latent_coords = self.diffusion_model.patchifier.get_latent_coords(frame_count, height, width, 1, cond_value.cond.device)
        scale_factors = self.diffusion_model.vae_scale_factors
        pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
            latent_coords,
            scale_factors,
            causal_fix=self.diffusion_model.causal_temporal_positioning)
        token_ids = [
            token
            for pos in local_positions
            for token in range(pos * tokens_per_frame, (pos + 1) * tokens_per_frame)
        ]
        pixel_coords = pixel_coords[:, :, token_ids, :]

        # Adjust spatial end positions for dilated (downscaled) guides.
        # Each guide entry may have a different downscale factor; expand the
        # per-entry factor to cover all tokens belonging to that entry.
        downscale_factors = window.guide_downscale_factors
        overlap_info = window.guide_overlap_info
        if downscale_factors:
            per_token = []
            for (_, overlap_count), factor in zip(overlap_info, downscale_factors):
                per_token.extend([factor] * (overlap_count * tokens_per_frame))
            factor_tensor = torch.tensor(per_token, device=pixel_coords.device, dtype=pixel_coords.dtype)
            spatial_scale = torch.tensor(
                scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
            ).view(1, -1, 1, 1)
            pixel_coords[:, 1:, :, 1:] += (factor_tensor[None, None, :, None] - 1) * spatial_scale

        batch = cond_value.cond.shape[0]
        if batch > 1:
            pixel_coords = pixel_coords.expand(batch, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)

    # Guide attention entries — adjust per-guide counts based on window overlap
    if cond_key == "guide_attention_entries":
        overlap_info = window.guide_overlap_info
        height, width = x_in.shape[3], x_in.shape[4]
        resized = [
            {**cond_value.cond[entry_idx],
             "pre_filter_count": overlap_count * height * width,
             "latent_shape": [overlap_count, height, width]}
            for entry_idx, overlap_count in overlap_info
        ]
        return cond_value._copy_with(resized)

    return None
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
@ -14,21 +14,22 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
description="Manually set context windows.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
|
||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
|
||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
|
||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
|
||||
io.Combo.Input("context_schedule", options=[
|
||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||
comfy.context_windows.ContextSchedules.BATCHED,
|
||||
], tooltip="The stride of the context window."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
||||
], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
|
||||
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
||||
],
|
||||
outputs=[
|
||||
@ -39,7 +40,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@ -52,6 +53,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows,
|
||||
latent_retain_index_list=latent_retain_index_list,
|
||||
causal_window_fix=causal_window_fix,
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
@ -65,33 +67,70 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
def define_schema(cls) -> io.Schema:
    """Return the parent schema re-labelled and re-parameterised for Wan-like models.

    Starts from ``super().define_schema()`` and overrides the node id,
    display name, description and the full input list. No ``dim`` input is
    declared here. ``context_length`` and ``context_overlap`` are described
    to the user in real frames.
    """
    schema = super().define_schema()
    schema.node_id = "WanContextWindowsManual"
    schema.display_name = "Wan Context Windows"
    schema.description = "Set context windows for Wan-like models."
    schema.inputs = [
        io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
        io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."),
        io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."),
        io.Combo.Input("context_schedule", options=[
            comfy.context_windows.ContextSchedules.STATIC_STANDARD,
            comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
            comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
            comfy.context_windows.ContextSchedules.BATCHED,
            ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
        io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
        io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
        io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
        io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
        io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."),
        io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
    ]
    return schema
|
||||
|
||||
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
            retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model:
    """Convert frame-based inputs to the parent's units and delegate to it.

    ``context_length`` is mapped with ``((n - 1) // 4) + 1`` and
    ``context_overlap`` with ``n // 4``, then both are passed to
    ``super().execute`` with ``dim=2``. ``retain_first_frame`` selects
    ``cond_retain_index_list="0"``; otherwise an empty string is passed.
    """
    context_length = max(((context_length - 1) // 4) + 1, 1)  # at least length 1
    context_overlap = max(context_overlap // 4, 0)  # at least overlap 0
    retain_index_list = "0" if retain_first_frame else ""
    return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||
|
||||
|
||||
class LTXVContextWindowsNode(ContextWindowsManualNode):
    """Context-window node preset for LTXV-like models.

    Reuses the parent node's schema and ``execute``. The schema is relabelled
    and its inputs replaced, and the frame-based length/overlap inputs are
    converted with a divisor of 8 before delegating with ``dim=2``.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Return the parent schema with LTXV-specific id, labels and inputs."""
        schedules = comfy.context_windows.ContextSchedules
        fuse_methods = comfy.context_windows.ContextFuseMethods
        schedule_options = [
            schedules.STATIC_STANDARD,
            schedules.UNIFORM_STANDARD,
            schedules.UNIFORM_LOOPED,
            schedules.BATCHED,
        ]

        schema = super().define_schema()
        schema.node_id = "LTXVContextWindows"
        schema.display_name = "LTXV Context Windows"
        schema.description = "Set context windows for LTXV-like models."
        schema.inputs = [
            io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
            io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."),
            io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."),
            io.Combo.Input("context_schedule", options=schedule_options, default=schedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
            io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
            io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
            io.Combo.Input("fuse_method", options=fuse_methods.LIST_STATIC, default=fuse_methods.PYRAMID, tooltip="The method to use to fuse the context windows."),
            io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
            io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."),
            io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
        ]
        return schema

    @classmethod
    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool,
                retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model:
        """Convert frame counts with a divisor of 8 and delegate to the parent node.

        The same retain list ("0" or "") is passed for both
        ``cond_retain_index_list`` and ``latent_retain_index_list``.
        """
        latent_length = max(((context_length - 1) // 8) + 1, 1)  # at least length 1
        latent_overlap = max(context_overlap // 8, 0)  # at least overlap 0
        if retain_first_frame:
            retained = "0"
        else:
            retained = ""
        return super().execute(
            model, latent_length, latent_overlap, context_schedule, context_stride, closed_loop, fuse_method,
            dim=2, freenoise=freenoise,
            cond_retain_index_list=retained, latent_retain_index_list=retained,
            split_conds_to_windows=split_conds_to_windows,
        )
|
||||
|
||||
|
||||
class ContextWindowsExtension(ComfyExtension):
|
||||
@ -99,6 +138,7 @@ class ContextWindowsExtension(ComfyExtension):
|
||||
return [
|
||||
ContextWindowsManualNode,
|
||||
WanContextWindowsManualNode,
|
||||
LTXVContextWindowsNode,
|
||||
]
|
||||
|
||||
def comfy_entrypoint():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user