mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-13 10:42:59 +08:00
Merge 75e8e4b6dc into c011fb520c
This commit is contained in:
commit
6274df0c61
@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
|
|||||||
import logging
|
import logging
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.conds
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.model_base import BaseModel
|
from comfy.model_base import BaseModel
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class IndexListContextWindow(ContextWindowABC):
|
class IndexListContextWindow(ContextWindowABC):
|
||||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
    """A context window defined by an explicit list of frame indices.

    Args:
        index_list: frame indices (into the full video) covered by this window
        dim: temporal dimension the window slices along
        total_frames: total frame count of the full video (0 means unknown)
        modality_windows: dict of {mod_idx: IndexListContextWindow} for multimodal models
        context_overlap: number of frames shared with neighboring windows
    """
    self.index_list = index_list
    self.context_length = len(index_list)
    self.context_overlap = context_overlap
    self.dim = dim
    self.total_frames = total_frames
    # Relative center of this window within the full video (used for region selection).
    # Guard the default total_frames=0, which previously raised ZeroDivisionError;
    # fall back to the midpoint so region selection still behaves sensibly.
    if total_frames > 0:
        self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
    else:
        self.center_ratio = 0.5
    self.modality_windows = modality_windows  # dict of {mod_idx: IndexListContextWindow}
    # Guide-frame bookkeeping, populated when guide frames are injected for this window.
    self.guide_frames_indices: list[int] = []
    self.guide_overlap_info: list[tuple[int, int]] = []
    self.guide_kf_local_positions: list[int] = []
    self.guide_downscale_factors: list[int] = []
|
||||||
|
|
||||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC):
|
|||||||
region_idx = int(self.center_ratio * num_regions)
|
region_idx = int(self.center_ratio * num_regions)
|
||||||
return min(max(region_idx, 0), num_regions - 1)
|
return min(max(region_idx, 0), num_regions - 1)
|
||||||
|
|
||||||
|
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
    """Return the context window for the given modality.

    Modality 0 (the primary modality) is this window itself; any other index is
    looked up in self.modality_windows.
    """
    return self if modality_idx == 0 else self.modality_windows[modality_idx]
|
||||||
|
|
||||||
|
|
||||||
class IndexListCallbacks:
|
class IndexListCallbacks:
|
||||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||||
@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
|||||||
return cond_value._copy_with(sliced)
|
return cond_value._copy_with(sliced)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
    """Compute which concatenated guide frames overlap with a context window.

    Each guide's latent-space start is derived from its first token's pixel-t-start
    in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), ceil-divided by
    the model's temporal_downscale_ratio.

    Args:
        guide_entries: list of guide_attention_entry dicts
        keyframe_idxs: per-token pixel coords cond tensor for the modality
        temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
        window_index_list: the window's frame indices into the video portion

    Returns:
        suffix_indices: indices into the guide_frames tensor for frame selection
        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
        total_overlap: total number of overlapping guide frames
    """
    # Map each window frame to its first local position. One dict gives O(1)
    # membership AND position lookup, replacing the previous `in set` check plus
    # an O(n) list.index() call per guide frame. setdefault keeps the FIRST
    # occurrence, matching list.index() semantics if indices ever repeat.
    frame_to_pos: dict[int, int] = {}
    for pos, frame in enumerate(window_index_list):
        frame_to_pos.setdefault(frame, pos)

    suffix_indices = []
    overlap_info = []
    kf_local_positions = []
    suffix_base = 0   # running offset into the concatenated guide_frames tensor
    token_offset = 0  # running offset into keyframe_idxs' token axis

    for entry_idx, entry in enumerate(guide_entries):
        # ceil-divide the first token's pixel-t start to get the latent start frame
        first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
        latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
        guide_len = entry["latent_shape"][0]
        entry_overlap = 0

        for local_offset in range(guide_len):
            video_pos = latent_start + local_offset
            pos = frame_to_pos.get(video_pos)
            if pos is not None:
                suffix_indices.append(suffix_base + local_offset)
                kf_local_positions.append(pos)
                entry_overlap += 1

        if entry_overlap > 0:
            overlap_info.append((entry_idx, entry_overlap))
        suffix_base += guide_len
        token_offset += entry["pre_filter_count"]

    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WindowingState:
    """Per-modality context windowing state for each step,
    built using IndexListContextHandler._build_window_state().

    For non-multimodal models the lists are length 1.
    """
    latents: list[torch.Tensor]  # per-modality working latents (guide frames stripped)
    guide_latents: list[torch.Tensor | None]  # per-modality guide frames stripped from latents
    guide_entries: list[list[dict] | None]  # per-modality guide_attention_entry metadata
    keyframe_idxs: list[torch.Tensor | None]  # per-modality keyframe_idxs tensor for guide latent_start derivation
    latent_shapes: list | None  # original packed shapes for unpack/pack (None if not multimodal)
    dim: int = 0  # primary modality temporal dim for context windowing
    is_multimodal: bool = False
    temporal_downscale_ratio: int = 1  # model's pixel-to-latent temporal compression ratio

    def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
        """Reformat window for multimodal contexts by deriving per-modality index lists.

        Non-multimodal contexts return the input window unchanged.
        """
        if not self.is_multimodal:
            return window

        x = self.latents[0]
        primary_total = self.latent_shapes[0][self.dim]
        primary_overlap = window.context_overlap
        map_shapes = self.latent_shapes
        # Working latents may have had guide frames stripped; rebuild the primary
        # shape so the mapping sees the actual working frame count.
        if x.size(self.dim) != primary_total:
            map_shapes = list(self.latent_shapes)
            video_shape = list(self.latent_shapes[0])
            video_shape[self.dim] = x.size(self.dim)
            map_shapes[0] = torch.Size(video_shape)
        # Explicit capability check instead of try/except AttributeError: the old
        # form also swallowed AttributeErrors raised *inside* a correctly
        # implemented mapper and mislabeled them as NotImplementedError.
        mapper = getattr(model, 'map_context_window_to_modalities', None)
        if mapper is None:
            raise NotImplementedError(
                f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
        per_modality_indices = mapper(window.index_list, map_shapes, self.dim)
        modality_windows = {}
        for mod_idx in range(1, len(self.latents)):
            modality_total_frames = self.latents[mod_idx].shape[self.dim]
            # Scale the overlap by the modality's frame-count ratio to the primary.
            ratio = modality_total_frames / primary_total if primary_total > 0 else 1
            modality_overlap = max(round(primary_overlap * ratio), 0)
            modality_windows[mod_idx] = IndexListContextWindow(
                per_modality_indices[mod_idx], dim=self.dim,
                total_frames=modality_total_frames,
                context_overlap=modality_overlap)
        return IndexListContextWindow(
            window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
            modality_windows=modality_windows, context_overlap=primary_overlap)

    def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
        """Slice latents for a context window, injecting guide frames where applicable.

        For multimodal contexts, uses the modality-specific windows derived in prepare_window().

        Returns:
            (sliced latents per modality, injected guide frame count per modality)
        """
        sliced = []
        guide_frame_counts = []
        for idx in range(len(self.latents)):
            modality_window = window.get_window_for_modality(idx)
            # latent retention only applies to the primary modality
            retain = retain_index_list if idx == 0 else []
            s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain)
            if self.guide_entries[idx] is not None:
                s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
            else:
                ng = 0
            sliced.append(s)
            guide_frame_counts.append(ng)
        return sliced, guide_frame_counts

    def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
        """Strip injected guide frames from per-cond, per-modality outputs in place."""
        for idx in range(len(self.latents)):
            if guide_frame_counts[idx] > 0:
                window_len = len(window.get_window_for_modality(idx).index_list)
                for ci in range(len(out_per_modality)):
                    out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len)

    def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
        """Append overlapping guide frames to latent_slice and record guide metadata on window.

        Returns the (possibly extended) latent slice and the number of injected frames.
        """
        guide_entries = self.guide_entries[modality_idx]
        guide_frames = self.guide_latents[modality_idx]
        keyframe_idxs = self.keyframe_idxs[modality_idx]
        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
            guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
        # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
        anchor_idx = getattr(window, 'causal_anchor_index', None)
        if anchor_idx is not None and anchor_idx >= 0:
            kf_local_pos = [p + 1 for p in kf_local_pos]
        window.guide_frames_indices = suffix_idx
        window.guide_overlap_info = overlap_info
        window.guide_kf_local_positions = kf_local_pos

        # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
        # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
        guide_downscale_factors = []
        if guide_frame_count > 0:
            full_H = guide_frames.shape[3]
            for entry_idx, _ in overlap_info:
                entry_H = guide_entries[entry_idx]["latent_shape"][1]
                guide_downscale_factors.append(full_H // entry_H)
        window.guide_downscale_factors = guide_downscale_factors

        if guide_frame_count > 0:
            # Select only the overlapping guide frames along the temporal dim.
            idx = tuple([slice(None)] * self.dim + [suffix_idx])
            return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count
        return latent_slice, 0

    def patch_latent_shapes(self, sub_conds, new_shapes):
        """Replace every 'latent_shapes' model_cond in sub_conds with new_shapes (multimodal only)."""
        if not self.is_multimodal:
            return

        for cond_list in sub_conds:
            if cond_list is None:
                continue
            for cond_dict in cond_list:
                model_conds = cond_dict.get('model_conds', {})
                if 'latent_shapes' in model_conds:
                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ContextSchedule:
|
class ContextSchedule:
|
||||||
name: str
|
name: str
|
||||||
@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
|
|||||||
class IndexListContextHandler(ContextHandlerABC):
|
class IndexListContextHandler(ContextHandlerABC):
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||||
causal_window_fix: bool=True):
|
latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
|
||||||
self.context_schedule = context_schedule
|
self.context_schedule = context_schedule
|
||||||
self.fuse_method = fuse_method
|
self.fuse_method = fuse_method
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self.freenoise = freenoise
|
self.freenoise = freenoise
|
||||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||||
self.split_conds_to_windows = split_conds_to_windows
|
self.split_conds_to_windows = split_conds_to_windows
|
||||||
|
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
|
||||||
self.causal_window_fix = causal_window_fix
|
self.causal_window_fix = causal_window_fix
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks = {}
|
||||||
|
|
||||||
|
@staticmethod
def _get_latent_shapes(conds):
    """Return the .cond of the first 'latent_shapes' model_cond found in conds, or None."""
    for group in (g for g in conds if g is not None):
        for cond_dict in group:
            model_conds = cond_dict.get('model_conds', {})
            if 'latent_shapes' in model_conds:
                return model_conds['latent_shapes'].cond
    return None
|
||||||
|
|
||||||
|
@staticmethod
def _get_guide_entries(conds):
    """Return the first non-empty 'guide_attention_entries' cond found in conds, or None."""
    for group in (g for g in conds if g is not None):
        for cond_dict in group:
            candidate = cond_dict.get('model_conds', {}).get('guide_attention_entries')
            # require an object with a truthy .cond (same effect as the hasattr check)
            if candidate is not None and getattr(candidate, 'cond', None):
                return candidate.cond
    return None
|
||||||
|
|
||||||
|
@staticmethod
def _get_keyframe_idxs(conds):
    """Return the first non-None 'keyframe_idxs' cond found in conds, or None."""
    for group in (g for g in conds if g is not None):
        for cond_dict in group:
            kf = cond_dict.get('model_conds', {}).get('keyframe_idxs')
            if kf is None or not hasattr(kf, 'cond'):
                continue
            # explicit None check (not truthiness) — .cond may be a tensor
            if kf.cond is not None:
                return kf.cond
    return None
|
||||||
|
|
||||||
|
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
    """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.

    If guide frames are present on the primary modality, only the video portion is shuffled.
    """
    guide_entries = self._get_guide_entries(conds)
    guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0

    latent_shapes = self._get_latent_shapes(conds)
    if latent_shapes is not None and len(latent_shapes) > 1:
        # Multimodal: unpack, shuffle each modality with scaled parameters, repack.
        modalities = comfy.utils.unpack_latents(noise, latent_shapes)
        primary_total = latent_shapes[0][self.dim]
        primary_video_count = modalities[0].size(self.dim) - guide_count
        # NOTE(review): primary modality's return value is discarded — assumes
        # apply_freenoise shuffles the narrowed view in place; confirm against its impl.
        apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)
        for i in range(1, len(modalities)):
            mod_total = latent_shapes[i][self.dim]
            # scale context params by the modality's frame-count ratio to the primary
            ratio = mod_total / primary_total if primary_total > 0 else 1
            mod_ctx_len = max(round(self.context_length * ratio), 1)
            mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
            modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
        noise, _ = comfy.utils.pack_latents(modalities)
        return noise

    # Single modality: shuffle only the video portion (guide frames excluded).
    video_count = noise.size(self.dim) - guide_count
    apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
    return noise
|
||||||
|
|
||||||
|
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
    """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
    latent_shapes = self._get_latent_shapes(conds)
    is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
    unpacked = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in]

    modality_count = len(unpacked)
    working_latents = list(unpacked)
    guide_latents_list = [None] * modality_count
    guide_entries_list = [None] * modality_count
    keyframe_idxs_list = [None] * modality_count

    found_entries = self._get_guide_entries(conds)
    found_keyframe_idxs = self._get_keyframe_idxs(conds)

    # Strip guide frames (only from first modality for now): guide frames are
    # concatenated at the END of the primary modality's temporal dim.
    if found_entries is not None:
        guide_count = sum(e["latent_shape"][0] for e in found_entries)
        if guide_count > 0:
            primary = unpacked[0]
            video_count = primary.size(self.dim) - guide_count
            working_latents[0] = primary.narrow(self.dim, 0, video_count)
            guide_latents_list[0] = primary.narrow(self.dim, video_count, guide_count)
            guide_entries_list[0] = found_entries
            keyframe_idxs_list[0] = found_keyframe_idxs

    return WindowingState(
        latents=working_latents,
        guide_latents=guide_latents_list,
        guide_entries=guide_entries_list,
        keyframe_idxs=keyframe_idxs_list,
        latent_shapes=latent_shapes,
        dim=self.dim,
        is_multimodal=is_multimodal,
        temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
|
||||||
|
|
||||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
    """Return True when the (guide-stripped) primary frame count exceeds context_length."""
    # Build window_state to check frame counts; execute() builds it again for the real pass.
    window_state = self._build_window_state(x_in, conds, model)
    total_frame_count = window_state.latents[0].size(self.dim)
    if total_frame_count <= self.context_length:
        logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
        return False
    logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
    if self.cond_retain_index_list:
        logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
    if self.latent_retain_index_list:
        logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
    return True
|
||||||
|
|
||||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||||
@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return resized_cond
|
return resized_cond
|
||||||
|
|
||||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
|
||||||
|
current_timestep = timestep[0].to(sample_sigmas.dtype)
|
||||||
|
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
|
||||||
matches = torch.nonzero(mask)
|
matches = torch.nonzero(mask)
|
||||||
if torch.numel(matches) == 0:
|
if torch.numel(matches) == 0:
|
||||||
return # substep from multi-step sampler: keep self._step from the last full step
|
return # substep from multi-step sampler: keep self._step from the last full step
|
||||||
@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
    """Wrap the schedule's raw index lists into IndexListContextWindow objects."""
    full_length = x_in.size(self.dim)  # TODO: choose dim based on model
    raw_windows = self.context_schedule.func(full_length, self, model_options)
    return [
        IndexListContextWindow(win, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap)
        for win in raw_windows
    ]
|
||||||
|
|
||||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
    """Run the model over every context window and fuse per-window outputs into full-length conds.

    Outputs are accumulated per modality (a single modality for non-multimodal models),
    normalized by usage counts unless fusing is RELATIVE, re-joined with any stripped
    guide frames, then packed back together when multimodal.
    """
    self._model = model
    self.set_step(timestep, model_options)
    window_state = self._build_window_state(x_in, conds, model)
    num_modalities = len(window_state.latents)

    context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
    enumerated_context_windows = list(enumerate(context_windows))
    total_windows = len(enumerated_context_windows)

    # Per-modality, per-cond accumulators (length 1 for single-modality).
    accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
    if self.fuse_method.name == ContextFuseMethods.RELATIVE:
        counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
    else:
        counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
    biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]

    for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
        callback(self, model, x_in, conds, timestep, model_options)

    # Accumulate results from each context window.
    for enum_window in enumerated_context_windows:
        results = self.evaluate_context_windows(
            calc_cond_batch, model, x_in, conds, timestep, [enum_window],
            model_options, window_state=window_state, total_windows=total_windows)
        for result in results:
            # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
            for mod_idx in range(num_modalities):
                mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
                modality_window = result.window.get_window_for_modality(mod_idx)
                self.combine_context_window_results(
                    window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
                    result.window_idx, total_windows, timestep,
                    accum[mod_idx], counts[mod_idx], biases[mod_idx])

    # Fuse accumulated results into final conds.
    try:
        result_out = []
        for ci in range(len(conds)):
            finalized = []
            for mod_idx in range(num_modalities):
                # RELATIVE fusing is already normalized; otherwise divide by usage counts.
                if self.fuse_method.name != ContextFuseMethods.RELATIVE:
                    accum[mod_idx][ci] /= counts[mod_idx][ci]
                fused = accum[mod_idx][ci]
                # if guide frames were injected, append them to the end of the fused latents for the next step
                if window_state.guide_latents[mod_idx] is not None:
                    fused = torch.cat([fused, window_state.guide_latents[mod_idx]], dim=self.dim)
                finalized.append(fused)
            # Pack modalities back together if needed.
            if window_state.is_multimodal and len(finalized) > 1:
                packed, _ = comfy.utils.pack_latents(finalized)
            else:
                packed = finalized[0]
            result_out.append(packed)
        return result_out
    finally:
        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
            callback(self, model, x_in, conds, timestep, model_options)
|
||||||
|
|
||||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
|
||||||
model_options, device=None, first_device=None):
|
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||||
|
model_options, window_state: WindowingState, total_windows: int = None,
|
||||||
|
device=None, first_device=None):
|
||||||
|
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
|
||||||
|
|
||||||
|
For each window:
|
||||||
|
1. Builds windows (for each modality if multimodal)
|
||||||
|
2. Slices window for each modality
|
||||||
|
3. Injects concatenated latent guide frames where present
|
||||||
|
4. Packs together if needed and calls model
|
||||||
|
5. Unpacks and strips any guides from outputs
|
||||||
|
"""
|
||||||
|
x = window_state.latents[0]
|
||||||
|
|
||||||
results: list[ContextResults] = []
|
results: list[ContextResults] = []
|
||||||
for window_idx, window in enumerated_context_windows:
|
for window_idx, window in enumerated_context_windows:
|
||||||
# allow processing to end between context window executions for faster Cancel
|
# allow processing to end between context window executions for faster Cancel
|
||||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
# prepare the window accounting for multimodal windows
|
||||||
|
window = window_state.prepare_window(window, model)
|
||||||
|
|
||||||
|
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
|
||||||
|
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
|
||||||
anchor_applied = False
|
anchor_applied = False
|
||||||
if self.causal_window_fix:
|
if self.causal_window_fix:
|
||||||
anchor_idx = window.index_list[0] - 1
|
anchor_idx = window.index_list[0] - 1
|
||||||
@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
window.causal_anchor_index = anchor_idx
|
window.causal_anchor_index = anchor_idx
|
||||||
anchor_applied = True
|
anchor_applied = True
|
||||||
|
|
||||||
|
# slice the window for each modality, injecting guide frames where applicable
|
||||||
|
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
|
||||||
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||||
|
|
||||||
# update exposed params
|
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
|
||||||
|
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
|
||||||
|
)
|
||||||
|
|
||||||
|
# if multimodal, pack modalities together
|
||||||
|
if window_state.is_multimodal and len(sliced) > 1:
|
||||||
|
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
|
||||||
|
else:
|
||||||
|
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
|
||||||
|
|
||||||
|
# get resized conds for window
|
||||||
model_options["transformer_options"]["context_window"] = window
|
model_options["transformer_options"]["context_window"] = window
|
||||||
# get subsections of x, timestep, conds
|
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||||
sub_x = window.get_tensor(x_in, device)
|
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
|
||||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
|
||||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
|
||||||
|
|
||||||
|
# if multimodal, patch latent_shapes in conds for correct unpacking in model
|
||||||
|
window_state.patch_latent_shapes(sub_conds, sub_shapes)
|
||||||
|
|
||||||
|
# call model on window
|
||||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||||
if device is not None:
|
|
||||||
for i in range(len(sub_conds_out)):
|
|
||||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
|
||||||
|
|
||||||
# strip causal_window_fix anchor if applied
|
# unpack outputs
|
||||||
|
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||||
|
|
||||||
|
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
|
||||||
if anchor_applied:
|
if anchor_applied:
|
||||||
for i in range(len(sub_conds_out)):
|
for ci in range(len(out_per_modality)):
|
||||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
t = out_per_modality[ci][0]
|
||||||
|
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
|
||||||
|
|
||||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
# strip injected guide frames
|
||||||
|
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
|
||||||
|
|
||||||
|
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
biases_final[i][idx] = bias_total + bias
|
biases_final[i][idx] = bias_total + bias
|
||||||
else:
|
else:
|
||||||
# add conds and counts based on weights of fuse method
|
# add conds and counts based on weights of fuse method
|
||||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
|
||||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||||
for i in range(len(sub_conds_out)):
|
for i in range(len(sub_conds_out)):
|
||||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||||
@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||||
|
|
||||||
|
|
||||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
|
||||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
# Scale noise_shape to a single context window so VRAM estimation budgets per-window.
|
||||||
model_options = kwargs.get("model_options", None)
|
model_options = kwargs.get("model_options", None)
|
||||||
if model_options is None:
|
if model_options is None:
|
||||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||||
if handler is not None:
|
if handler is not None:
|
||||||
noise_shape = list(noise_shape)
|
noise_shape = list(noise_shape)
|
||||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
|
||||||
return executor(model, noise_shape, *args, **kwargs)
|
if is_packed:
|
||||||
|
# TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
|
||||||
|
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
|
||||||
|
pass
|
||||||
|
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
|
||||||
|
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||||
|
return executor(model, noise_shape, conds, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||||
@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
|
|||||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||||
if not handler.freenoise:
|
if not handler.freenoise:
|
||||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
|
||||||
|
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
|
||||||
|
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
|
||||||
|
|
||||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||||
model.add_wrapper_with_key(
|
model.add_wrapper_with_key(
|
||||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||||
@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
|
|||||||
_sampler_sample_wrapper
|
_sampler_sample_wrapper
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||||
total_dims = len(x_in.shape)
|
total_dims = len(x_in.shape)
|
||||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||||
@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
|||||||
return ContextSchedule(context_schedule, func)
|
return ContextSchedule(context_schedule, func)
|
||||||
|
|
||||||
|
|
||||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
|
||||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
context_overlap = handler.context_overlap if context_overlap is None else context_overlap
|
||||||
|
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
|
||||||
|
|
||||||
|
|
||||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||||
@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
|||||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||||
return weight_sequence
|
return weight_sequence
|
||||||
|
|
||||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
|
||||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||||
# only expected overlap is given different weights
|
# only expected overlap is given different weights
|
||||||
weights_torch = torch.ones((length))
|
weights_torch = torch.ones((length))
|
||||||
# blend left-side on all except first window
|
# blend left-side on all except first window
|
||||||
if min(idxs) > 0:
|
if min(idxs) > 0:
|
||||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
ramp_up = torch.linspace(1e-37, 1, context_overlap)
|
||||||
weights_torch[:handler.context_overlap] = ramp_up
|
weights_torch[:context_overlap] = ramp_up
|
||||||
# blend right-side on all except last window
|
# blend right-side on all except last window
|
||||||
if max(idxs) < full_length-1:
|
if max(idxs) < full_length-1:
|
||||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
ramp_down = torch.linspace(1, 1e-37, context_overlap)
|
||||||
weights_torch[-handler.context_overlap:] = ramp_down
|
weights_torch[-context_overlap:] = ramp_down
|
||||||
return weights_torch
|
return weights_torch
|
||||||
|
|
||||||
class ContextFuseMethods:
|
class ContextFuseMethods:
|
||||||
|
|||||||
@ -1028,7 +1028,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
grid_mask = None
|
grid_mask = None
|
||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||||
@ -1315,7 +1315,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
x = x * (1 + scale) + shift
|
x = x * (1 + scale) + shift
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|
||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||||
grid_mask = kwargs["grid_mask"]
|
grid_mask = kwargs["grid_mask"]
|
||||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import comfy.ldm.lightricks.av_model
|
import comfy.ldm.lightricks.av_model
|
||||||
|
import comfy.ldm.lightricks.symmetric_patchifier
|
||||||
import comfy.context_windows
|
import comfy.context_windows
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
from comfy.ldm.cascade.stage_c import StageC
|
from comfy.ldm.cascade.stage_c import StageC
|
||||||
@ -1094,6 +1095,127 @@ class LTXAV(BaseModel):
|
|||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
|
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
|
||||||
|
result = [primary_indices]
|
||||||
|
if len(latent_shapes) < 2:
|
||||||
|
return result
|
||||||
|
|
||||||
|
video_total = latent_shapes[0][dim]
|
||||||
|
|
||||||
|
for i in range(1, len(latent_shapes)):
|
||||||
|
mod_total = latent_shapes[i][dim]
|
||||||
|
# Map each primary index to its proportional range of modality indices and
|
||||||
|
# concatenate in order. Preserves wrapped/strided geometry so the modality
|
||||||
|
# attends to the same temporal regions as the primary window.
|
||||||
|
mod_indices = []
|
||||||
|
seen = set()
|
||||||
|
for v_idx in primary_indices:
|
||||||
|
a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
|
||||||
|
a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
|
||||||
|
if a_end <= a_start:
|
||||||
|
a_end = a_start + 1
|
||||||
|
for a in range(a_start, a_end):
|
||||||
|
if a not in seen:
|
||||||
|
seen.add(a)
|
||||||
|
mod_indices.append(a)
|
||||||
|
result.append(mod_indices)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_guide_entries(conds):
|
||||||
|
for cond_list in conds:
|
||||||
|
if cond_list is None:
|
||||||
|
continue
|
||||||
|
for cond_dict in cond_list:
|
||||||
|
model_conds = cond_dict.get('model_conds', {})
|
||||||
|
entries = model_conds.get('guide_attention_entries')
|
||||||
|
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||||
|
return entries.cond
|
||||||
|
return None
|
||||||
|
|
||||||
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
|
# Audio denoise mask — slice using audio modality window
|
||||||
|
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
|
||||||
|
audio_window = window.modality_windows.get(1)
|
||||||
|
if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
|
||||||
|
return cond_value._copy_with(sliced)
|
||||||
|
|
||||||
|
# Video denoise mask — split into video + guide portions, slice each
|
||||||
|
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
cond_tensor = cond_value.cond
|
||||||
|
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
||||||
|
if guide_count > 0:
|
||||||
|
T_video = x_in.size(window.dim)
|
||||||
|
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||||
|
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||||
|
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||||
|
suffix_indices = window.guide_frames_indices
|
||||||
|
if suffix_indices:
|
||||||
|
idx = tuple([slice(None)] * window.dim + [suffix_indices])
|
||||||
|
sliced_guide = guide_mask[idx].to(device)
|
||||||
|
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||||
|
else:
|
||||||
|
return cond_value._copy_with(sliced_video)
|
||||||
|
|
||||||
|
# Keyframe indices — regenerate pixel coords for window, select guide positions
|
||||||
|
if cond_key == "keyframe_idxs":
|
||||||
|
kf_local_pos = window.guide_kf_local_positions
|
||||||
|
if not kf_local_pos:
|
||||||
|
return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
|
||||||
|
H, W = x_in.shape[3], x_in.shape[4]
|
||||||
|
window_len = len(window.index_list)
|
||||||
|
# account for causal_window_fix anchor in coord space size
|
||||||
|
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||||
|
if anchor_idx is not None and anchor_idx >= 0:
|
||||||
|
window_len += 1
|
||||||
|
patchifier = self.diffusion_model.patchifier
|
||||||
|
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||||
|
scale_factors = self.diffusion_model.vae_scale_factors
|
||||||
|
pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
|
||||||
|
latent_coords,
|
||||||
|
scale_factors,
|
||||||
|
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||||
|
tokens = []
|
||||||
|
for pos in kf_local_pos:
|
||||||
|
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
|
||||||
|
pixel_coords = pixel_coords[:, :, tokens, :]
|
||||||
|
|
||||||
|
# Adjust spatial end positions for dilated (downscaled) guides.
|
||||||
|
# Each guide entry may have a different downscale factor; expand the
|
||||||
|
# per-entry factor to cover all tokens belonging to that entry.
|
||||||
|
downscale_factors = window.guide_downscale_factors
|
||||||
|
overlap_info = window.guide_overlap_info
|
||||||
|
if downscale_factors:
|
||||||
|
per_token_factor = []
|
||||||
|
for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
|
||||||
|
per_token_factor.extend([dsf] * (overlap_count * H * W))
|
||||||
|
factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
|
||||||
|
spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor(
|
||||||
|
scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
|
||||||
|
).view(1, -1, 1, 1)
|
||||||
|
pixel_coords[:, 1:, :, 1:] += spatial_end_offset
|
||||||
|
|
||||||
|
B = cond_value.cond.shape[0]
|
||||||
|
if B > 1:
|
||||||
|
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||||
|
return cond_value._copy_with(pixel_coords)
|
||||||
|
|
||||||
|
# Guide attention entries — adjust per-guide counts based on window overlap
|
||||||
|
if cond_key == "guide_attention_entries":
|
||||||
|
overlap_info = window.guide_overlap_info
|
||||||
|
H, W = x_in.shape[3], x_in.shape[4]
|
||||||
|
new_entries = []
|
||||||
|
for entry_idx, overlap_count in overlap_info:
|
||||||
|
e = cond_value.cond[entry_idx]
|
||||||
|
new_entries.append({**e,
|
||||||
|
"pre_filter_count": overlap_count * H * W,
|
||||||
|
"latent_shape": [overlap_count, H, W]})
|
||||||
|
return cond_value._copy_with(new_entries)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
class HunyuanVideo(BaseModel):
|
class HunyuanVideo(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||||
|
|||||||
@ -14,21 +14,22 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
description="Manually set context windows.",
|
description="Manually set context windows.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||||
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True),
|
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
|
||||||
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True),
|
io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
|
||||||
io.Combo.Input("context_schedule", options=[
|
io.Combo.Input("context_schedule", options=[
|
||||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||||
comfy.context_windows.ContextSchedules.BATCHED,
|
comfy.context_windows.ContextSchedules.BATCHED,
|
||||||
], tooltip="The stride of the context window."),
|
], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
|
||||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."),
|
||||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
|
io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."),
|
||||||
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -39,7 +40,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model:
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||||
@ -52,6 +53,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
freenoise=freenoise,
|
freenoise=freenoise,
|
||||||
cond_retain_index_list=cond_retain_index_list,
|
cond_retain_index_list=cond_retain_index_list,
|
||||||
split_conds_to_windows=split_conds_to_windows,
|
split_conds_to_windows=split_conds_to_windows,
|
||||||
|
latent_retain_index_list=latent_retain_index_list,
|
||||||
causal_window_fix=causal_window_fix,
|
causal_window_fix=causal_window_fix,
|
||||||
)
|
)
|
||||||
# make memory usage calculation only take into account the context window latents
|
# make memory usage calculation only take into account the context window latents
|
||||||
@ -65,33 +67,70 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
schema = super().define_schema()
|
schema = super().define_schema()
|
||||||
schema.node_id = "WanContextWindowsManual"
|
schema.node_id = "WanContextWindowsManual"
|
||||||
schema.display_name = "WAN Context Windows (Manual)"
|
schema.display_name = "Wan Context Windows"
|
||||||
schema.description = "Manually set context windows for WAN-like models (dim=2)."
|
schema.description = "Set context windows for Wan-like models."
|
||||||
schema.inputs = [
|
schema.inputs = [
|
||||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||||
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True),
|
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. Must be 4*n + 1."),
|
||||||
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True),
|
io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."),
|
||||||
io.Combo.Input("context_schedule", options=[
|
io.Combo.Input("context_schedule", options=[
|
||||||
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||||
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||||
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||||
comfy.context_windows.ContextSchedules.BATCHED,
|
comfy.context_windows.ContextSchedules.BATCHED,
|
||||||
], tooltip="The stride of the context window."),
|
], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
|
||||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
|
||||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first I2V frame in every context window (may help retain initial reference)."),
|
||||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
|
||||||
]
|
]
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model:
|
||||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
context_overlap = max(context_overlap // 4, 0) # at least overlap 0
|
||||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
retain_index_list = "0" if retain_first_frame else ""
|
||||||
|
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVContextWindowsNode(ContextWindowsManualNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
schema = super().define_schema()
|
||||||
|
schema.node_id = "LTXVContextWindows"
|
||||||
|
schema.display_name = "LTXV Context Windows"
|
||||||
|
schema.description = "Set context windows for LTXV-like models."
|
||||||
|
schema.inputs = [
|
||||||
|
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||||
|
io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."),
|
||||||
|
io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."),
|
||||||
|
io.Combo.Input("context_schedule", options=[
|
||||||
|
comfy.context_windows.ContextSchedules.STATIC_STANDARD,
|
||||||
|
comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
|
||||||
|
comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
|
||||||
|
comfy.context_windows.ContextSchedules.BATCHED,
|
||||||
|
], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."),
|
||||||
|
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True),
|
||||||
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True),
|
||||||
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
|
io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True),
|
||||||
|
io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."),
|
||||||
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True),
|
||||||
|
]
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: bool,
|
||||||
|
retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model:
|
||||||
|
context_length = max(((context_length - 1) // 8) + 1, 1) # at least length 1
|
||||||
|
context_overlap = max(context_overlap // 8, 0) # at least overlap 0
|
||||||
|
retain_index_list = "0" if retain_first_frame else ""
|
||||||
|
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise,
|
||||||
|
cond_retain_index_list=retain_index_list, latent_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||||
|
|
||||||
|
|
||||||
class ContextWindowsExtension(ComfyExtension):
|
class ContextWindowsExtension(ComfyExtension):
|
||||||
@ -99,6 +138,7 @@ class ContextWindowsExtension(ComfyExtension):
|
|||||||
return [
|
return [
|
||||||
ContextWindowsManualNode,
|
ContextWindowsManualNode,
|
||||||
WanContextWindowsManualNode,
|
WanContextWindowsManualNode,
|
||||||
|
LTXVContextWindowsNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
def comfy_entrypoint():
|
def comfy_entrypoint():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user