mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
Merge 6442392810 into 3d816db07f
This commit is contained in:
commit
eea35bb8c1
@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.context_overlap = context_overlap
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
|
||||
self.guide_frames_indices: list[int] = []
|
||||
self.guide_overlap_info: list[tuple[int, int]] = []
|
||||
self.guide_kf_local_positions: list[int] = []
|
||||
self.guide_downscale_factors: list[int] = []
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
@ -81,6 +89,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
|
||||
if modality_idx == 0:
|
||||
return self
|
||||
return self.modality_windows[modality_idx]
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
@ -137,6 +150,157 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int]):
    """Compute which concatenated guide frames overlap with a context window.

    Each entry covers the video positions ``latent_start`` ..
    ``latent_start + latent_shape[0] - 1``; entries are laid out back-to-back
    (in list order) when indexing into the guide frames.

    Args:
        guide_entries: list of guide_attention_entry dicts
        window_index_list: the window's frame indices into the video portion

    Returns:
        suffix_indices: indices into the guide_frames tensor for frame selection
        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
        total_overlap: total number of overlapping guide frames

    Raises:
        ValueError: if an entry has no 'latent_start'.
    """
    # Map each frame index to its first position inside the window once, so that
    # both the membership test and the position lookup are O(1) per guide frame
    # (setdefault keeps the first occurrence, matching list.index()).
    window_positions: dict[int, int] = {}
    for position, frame_idx in enumerate(window_index_list):
        window_positions.setdefault(frame_idx, position)

    suffix_indices = []
    overlap_info = []
    kf_local_positions = []
    suffix_base = 0  # offset of the current entry inside the concatenated guide frames

    for entry_idx, entry in enumerate(guide_entries):
        latent_start = entry.get("latent_start", None)
        if latent_start is None:
            raise ValueError("guide_attention_entry missing required 'latent_start'.")
        guide_len = entry["latent_shape"][0]
        entry_overlap = 0

        for local_offset in range(guide_len):
            window_pos = window_positions.get(latent_start + local_offset)
            if window_pos is None:
                continue
            suffix_indices.append(suffix_base + local_offset)
            kf_local_positions.append(window_pos)
            entry_overlap += 1

        if entry_overlap > 0:
            overlap_info.append((entry_idx, entry_overlap))
        suffix_base += guide_len

    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
|
||||
@dataclass
class WindowingState:
    """Per-modality context windowing state for each step,
    built using IndexListContextHandler._build_window_state().
    For non-multimodal models the lists are length 1
    """
    latents: list[torch.Tensor]  # per-modality working latents (guide frames stripped)
    guide_latents: list[torch.Tensor | None]  # per-modality guide frames stripped from latents
    guide_entries: list[list[dict] | None]  # per-modality guide_attention_entry metadata
    latent_shapes: list | None  # original packed shapes for unpack/pack (None if not multimodal)
    dim: int = 0  # primary modality temporal dim for context windowing
    is_multimodal: bool = False

    def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
        """Reformat window for multimodal contexts by deriving per-modality index lists.
        Non-multimodal contexts return the input window unchanged.

        Raises:
            NotImplementedError: if ``model`` has no map_context_window_to_modalities hook.
        """
        if not self.is_multimodal:
            return window

        x = self.latents[0]
        # NOTE(review): primary_total comes from the original packed shape, which can
        # differ from the working latent (see the map_shapes fix-up below); confirm it
        # is the intended denominator for the per-modality overlap ratio.
        primary_total = self.latent_shapes[0][self.dim]
        primary_overlap = window.context_overlap
        map_shapes = self.latent_shapes
        if x.size(self.dim) != primary_total:
            # The working primary latent has a different length than the packed shape,
            # so hand the model a shape list that reflects the working length.
            map_shapes = list(self.latent_shapes)
            video_shape = list(self.latent_shapes[0])
            video_shape[self.dim] = x.size(self.dim)
            map_shapes[0] = torch.Size(video_shape)

        # Resolve the hook before calling it: only a *missing* hook means "not
        # implemented". An AttributeError raised inside the hook must propagate
        # unchanged instead of being reported as a missing implementation.
        map_to_modalities = getattr(model, "map_context_window_to_modalities", None)
        if map_to_modalities is None:
            raise NotImplementedError(
                f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
        per_modality_indices = map_to_modalities(window.index_list, map_shapes, self.dim)

        modality_windows = {}
        for mod_idx in range(1, len(self.latents)):
            modality_total_frames = self.latents[mod_idx].shape[self.dim]
            # scale the primary overlap by the modality/primary frame ratio
            ratio = modality_total_frames / primary_total if primary_total > 0 else 1
            modality_overlap = max(round(primary_overlap * ratio), 0)
            modality_windows[mod_idx] = IndexListContextWindow(
                per_modality_indices[mod_idx], dim=self.dim,
                total_frames=modality_total_frames,
                context_overlap=modality_overlap)
        return IndexListContextWindow(
            window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
            modality_windows=modality_windows, context_overlap=primary_overlap)

    def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
        """Slice latents for a context window, injecting guide frames where applicable.
        For multimodal contexts, uses the modality-specific windows derived in prepare_window().

        Returns:
            (sliced, guide_frame_counts): one sliced tensor and one injected-guide-frame
            count per modality.
        """
        sliced = []
        guide_frame_counts = []
        for idx, latent in enumerate(self.latents):
            modality_window = window.get_window_for_modality(idx)
            # retained indices are only passed for the primary modality
            retain = retain_index_list if idx == 0 else []
            s = modality_window.get_tensor(latent, device, retain_index_list=retain)
            if self.guide_entries[idx] is not None:
                s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
            else:
                ng = 0
            sliced.append(s)
            guide_frame_counts.append(ng)
        return sliced, guide_frame_counts

    def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
        """Strip injected guide frames from per-cond, per-modality outputs in place."""
        for idx in range(len(self.latents)):
            if guide_frame_counts[idx] <= 0:
                continue
            # guide frames were concatenated after the window frames, so keeping the
            # first window_len entries along dim drops them again
            window_len = len(window.get_window_for_modality(idx).index_list)
            for cond_out in out_per_modality:
                cond_out[idx] = cond_out[idx].narrow(self.dim, 0, window_len)

    def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
        """Append the guide frames overlapping ``window`` to ``latent_slice`` along dim.

        Also records the overlap bookkeeping on ``window`` (guide_frames_indices,
        guide_overlap_info, guide_kf_local_positions, guide_downscale_factors).

        Returns:
            (tensor, guide_frame_count); the input slice and 0 when nothing overlaps.
        """
        guide_entries = self.guide_entries[modality_idx]
        guide_frames = self.guide_latents[modality_idx]
        suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(guide_entries, window.index_list)
        window.guide_frames_indices = suffix_idx
        window.guide_overlap_info = overlap_info
        window.guide_kf_local_positions = kf_local_pos

        if guide_frame_count <= 0:
            window.guide_downscale_factors = []
            return latent_slice, 0

        # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
        # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
        # assumes axis 3 of guide_frames and index 1 of latent_shape are both height — TODO confirm
        full_H = guide_frames.shape[3]
        window.guide_downscale_factors = [
            full_H // guide_entries[entry_idx]["latent_shape"][1]
            for entry_idx, _ in overlap_info
        ]

        # select the overlapping guide frames along self.dim
        idx = tuple([slice(None)] * self.dim + [suffix_idx])
        return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count

    def patch_latent_shapes(self, sub_conds, new_shapes):
        """Replace every existing 'latent_shapes' model cond in sub_conds with new_shapes.

        Does nothing for non-multimodal state. Conds without a 'latent_shapes'
        entry are left untouched.
        """
        if not self.is_multimodal:
            return

        for cond_list in sub_conds:
            if cond_list is None:
                continue
            for cond_dict in cond_list:
                model_conds = cond_dict.get('model_conds', {})
                if 'latent_shapes' in model_conds:
                    model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@ -165,13 +329,94 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_latent_shapes(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
return model_conds['latent_shapes'].cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
    """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.

    If guide frames are present on the primary modality, only the video portion
    (the frames before the trailing guide frames) is shuffled.
    """
    guide_entries = self._get_guide_entries(conds)
    guide_count = 0
    if guide_entries:
        guide_count = sum(entry["latent_shape"][0] for entry in guide_entries)

    shapes = self._get_latent_shapes(conds)
    packed_multimodal = shapes is not None and len(shapes) > 1

    if not packed_multimodal:
        # Single latent: shuffle everything except the trailing guide frames.
        # The return value is ignored, so this relies on apply_freenoise modifying
        # the narrowed view in place — NOTE(review): confirm against apply_freenoise.
        frames_to_shuffle = noise.size(self.dim) - guide_count
        apply_freenoise(noise.narrow(self.dim, 0, frames_to_shuffle), self.dim, self.context_length, self.context_overlap, seed)
        return noise

    parts = comfy.utils.unpack_latents(noise, shapes)
    primary_total = shapes[0][self.dim]
    primary_frames = parts[0].size(self.dim) - guide_count
    # primary modality: same in-place reliance as above
    apply_freenoise(parts[0].narrow(self.dim, 0, primary_frames), self.dim, self.context_length, self.context_overlap, seed)

    for mod_idx in range(1, len(parts)):
        mod_total = shapes[mod_idx][self.dim]
        # scale the window geometry by how many frames this modality has per primary frame
        ratio = mod_total / primary_total if primary_total > 0 else 1
        scaled_length = max(round(self.context_length * ratio), 1)
        scaled_overlap = max(round(self.context_overlap * ratio), 0)
        parts[mod_idx] = apply_freenoise(parts[mod_idx], self.dim, scaled_length, scaled_overlap, seed)

    noise, _ = comfy.utils.pack_latents(parts)
    return noise
|
||||
|
||||
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState:
    """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
    latent_shapes = self._get_latent_shapes(conds)
    is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
    if is_multimodal:
        unpacked = comfy.utils.unpack_latents(x_in, latent_shapes)
    else:
        unpacked = [x_in]

    latents = list(unpacked)
    guide_latents = [None] * len(unpacked)
    guide_entries = [None] * len(unpacked)

    found_entries = self._get_guide_entries(conds)
    guide_count = 0
    if found_entries is not None:
        guide_count = sum(entry["latent_shape"][0] for entry in found_entries)

    # Strip guide frames (only from first modality for now): they occupy the
    # last guide_count positions along self.dim.
    if guide_count > 0:
        primary = unpacked[0]
        video_count = primary.size(self.dim) - guide_count
        latents[0] = primary.narrow(self.dim, 0, video_count)
        guide_latents[0] = primary.narrow(self.dim, video_count, guide_count)
        guide_entries[0] = found_entries

    return WindowingState(
        latents=latents,
        guide_latents=guide_latents,
        guide_entries=guide_entries,
        latent_shapes=latent_shapes,
        dim=self.dim,
        is_multimodal=is_multimodal)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
window_state = self._build_window_state(x_in, conds) # build window_state to check frame counts, will be built again in execute
|
||||
total_frame_count = window_state.latents[0].size(self.dim)
|
||||
if total_frame_count > self.context_length:
|
||||
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
@ -262,7 +507,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
|
||||
current_timestep = timestep[0].to(sample_sigmas.dtype)
|
||||
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
return # substep from multi-step sampler: keep self._step from the last full step
|
||||
@ -271,68 +518,128 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
    """Build the context windows for ``x_in`` from the configured schedule.

    The schedule function is called with the length of ``x_in`` along ``self.dim``
    and each index list it returns is wrapped exactly once in an
    IndexListContextWindow carrying the handler's context_overlap.
    """
    full_length = x_in.size(self.dim)  # TODO: choose dim based on model
    scheduled_index_lists = self.context_schedule.func(full_length, self, model_options)
    return [
        IndexListContextWindow(index_list, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap)
        for index_list in scheduled_index_lists
    ]
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
window_state = self._build_window_state(x_in, conds)
|
||||
num_modalities = len(window_state.latents)
|
||||
|
||||
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
total_windows = len(enumerated_context_windows)
|
||||
|
||||
# Initialize per-modality accumulators (length 1 for single-modality)
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
# accumulate results from each context window
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
results = self.evaluate_context_windows(
|
||||
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
|
||||
model_options, window_state=window_state, total_windows=total_windows)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
|
||||
for mod_idx in range(num_modalities):
|
||||
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
|
||||
modality_window = result.window.get_window_for_modality(mod_idx)
|
||||
self.combine_context_window_results(
|
||||
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
|
||||
result.window_idx, total_windows, timestep,
|
||||
accum[mod_idx], counts[mod_idx], biases[mod_idx])
|
||||
|
||||
# fuse accumulated results into final conds
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
result_out = []
|
||||
for ci in range(len(conds)):
|
||||
finalized = []
|
||||
for mod_idx in range(num_modalities):
|
||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||
f = accum[mod_idx][ci]
|
||||
|
||||
# if guide frames were injected, append them to the end of the fused latents for the next step
|
||||
if window_state.guide_latents[mod_idx] is not None:
|
||||
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
|
||||
finalized.append(f)
|
||||
|
||||
# pack modalities together if needed
|
||||
if window_state.is_multimodal and len(finalized) > 1:
|
||||
packed, _ = comfy.utils.pack_latents(finalized)
|
||||
else:
|
||||
packed = finalized[0]
|
||||
|
||||
result_out.append(packed)
|
||||
return result_out
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
|
||||
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, window_state: WindowingState, total_windows: int = None,
|
||||
device=None, first_device=None):
|
||||
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
|
||||
|
||||
For each window:
|
||||
1. Builds windows (for each modality if multimodal)
|
||||
2. Slices window for each modality
|
||||
3. Injects concatenated latent guide frames where present
|
||||
4. Packs together if needed and calls model
|
||||
5. Unpacks and strips any guides from outputs
|
||||
"""
|
||||
x = window_state.latents[0]
|
||||
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# prepare the window accounting for multimodal windows
|
||||
window = window_state.prepare_window(window, model)
|
||||
|
||||
# slice the window for each modality, injecting guide frames where applicable
|
||||
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
|
||||
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
|
||||
)
|
||||
|
||||
# if multimodal, pack modalities together
|
||||
if window_state.is_multimodal and len(sliced) > 1:
|
||||
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
|
||||
else:
|
||||
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
|
||||
|
||||
# get resized conds for window
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
|
||||
|
||||
# if multimodal, patch latent_shapes in conds for correct unpacking in model
|
||||
window_state.patch_latent_shapes(sub_conds, sub_shapes)
|
||||
|
||||
# call model on window
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
|
||||
# unpack outputs and strip guide frames
|
||||
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
|
||||
|
||||
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
@ -356,7 +663,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
@ -366,16 +673,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
    """Clamp noise_shape to one context window before delegating to ``executor``.

    Scales noise_shape to a single context window so VRAM estimation budgets
    per-window. Without a context handler in model_options the shape is passed
    through untouched.

    Raises:
        Exception: if ``model_options`` is missing from kwargs.
    """
    model_options = kwargs.get("model_options", None)
    if model_options is None:
        raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
    handler: IndexListContextHandler = model_options.get("context_handler", None)
    if handler is not None:
        noise_shape = list(noise_shape)
        is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
        # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
        # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
        if not is_packed and handler.dim < len(noise_shape):
            noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
    return executor(model, noise_shape, conds, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
@ -395,11 +708,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
|
||||
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
@ -407,7 +721,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@ -553,8 +866,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
    """Return the fuse weights for one window from the handler's fuse method.

    Args:
        length: number of frames in the window.
        full_length: number of frames in the full latent.
        idxs: the window's frame indices.
        handler: context handler whose ``fuse_method.func`` produces the weights.
        sigma: forwarded unchanged to the fuse function.
        context_overlap: overlap to use for this window; falls back to
            ``handler.context_overlap`` when None (an explicit 0 is respected).
    """
    if context_overlap is None:
        context_overlap = handler.context_overlap
    return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
@ -572,18 +886,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
    """Return per-frame fuse weights that linearly blend only the overlapping edges.

    Args:
        length: number of frames in the window.
        full_length: number of frames in the full latent.
        idxs: the window's frame indices.
        context_overlap: number of frames expected to overlap with a neighbouring
            window; clamped to ``[0, length]``.
        **kwargs: extra fuse-method arguments (e.g. handler, sigma); unused here.

    Returns:
        1-D tensor of ``length`` weights: 1 everywhere, ramping up over the first
        ``context_overlap`` frames (except for a window starting at frame 0) and
        down over the last ``context_overlap`` frames (except for a window ending
        at the last frame).
    """
    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
    # only expected overlap is given different weights
    weights_torch = torch.ones((length))

    # Clamp so the ramps always fit the window. With an overlap of 0 the slice
    # weights[-0:] would span the whole tensor and the empty ramp could not be
    # assigned to it, so a zero overlap simply means "no blending".
    overlap = max(0, min(context_overlap, length))
    if overlap == 0:
        return weights_torch

    # blend left-side on all except first window
    if min(idxs) > 0:
        weights_torch[:overlap] = torch.linspace(1e-37, 1, overlap)
    # blend right-side on all except last window
    if max(idxs) < full_length - 1:
        weights_torch[-overlap:] = torch.linspace(1, 1e-37, overlap)
    return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
|
||||
@ -1028,7 +1028,7 @@ class LTXVModel(LTXBaseModel):
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
@ -1315,7 +1315,7 @@ class LTXVModel(LTXBaseModel):
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
|
||||
@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.ldm.lightricks.symmetric_patchifier
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
@ -1083,6 +1084,123 @@ class LTXAV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Translate a primary-modality window into index lists for every modality.

    Args:
        primary_indices: indices of the window along ``dim`` of the first
            (primary) latent.
        latent_shapes: one shape per modality; entry 0 is the primary one.
        dim: axis of each shape that the window indices refer to.

    Returns:
        A list whose first element is ``primary_indices`` itself, followed by
        one list of indices per additional entry in ``latent_shapes``.
    """
    mapped = [primary_indices]
    if len(latent_shapes) < 2:
        # Only the primary modality exists; nothing to translate.
        return mapped

    primary_total = latent_shapes[0][dim]

    for shape in latent_shapes[1:]:
        mod_total = shape[dim]
        # Each primary index is scaled to a proportional [lo, hi) span in the
        # other modality. Spans are appended in the order the primary indices
        # arrive (so wrapped/strided windows keep their layout), and an index
        # is kept only the first time it shows up.
        ordered = {}
        for v_idx in primary_indices:
            lo = min(int(round(v_idx * mod_total / primary_total)), mod_total - 1)
            hi = min(int(round((v_idx + 1) * mod_total / primary_total)), mod_total)
            # Guarantee every primary index contributes at least one slot.
            hi = max(hi, lo + 1)
            for a in range(lo, hi):
                ordered.setdefault(a, None)
        mapped.append(list(ordered))

    return mapped
|
||||
|
||||
@staticmethod
def _get_guide_entries(conds):
    """Find the first non-empty ``guide_attention_entries`` payload in ``conds``.

    ``conds`` is iterated as a sequence of cond lists (``None`` items are
    skipped); each list holds dicts that may carry a ``'model_conds'`` mapping.
    Returns the ``.cond`` attribute of the first ``guide_attention_entries``
    object whose ``.cond`` is truthy, or ``None`` when no such entry exists.
    """
    for group in conds:
        if group is None:
            continue
        for item in group:
            candidate = item.get('model_conds', {}).get('guide_attention_entries')
            # A missing entry, an object without ``cond``, and an empty
            # ``cond`` are all treated as "keep looking".
            if candidate is None:
                continue
            if getattr(candidate, 'cond', None):
                return candidate.cond
    return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Rebuild selected conditioning values so they match one context window.

    Handles four keys: ``"audio_denoise_mask"``, ``"denoise_mask"``,
    ``"keyframe_idxs"`` and ``"guide_attention_entries"``. For a handled key
    the result is ``cond_value._copy_with(<windowed value>)``. For any other
    key, or when a handled key's preconditions are not met, ``None`` is
    returned.

    Args:
        cond_key: name of the conditioning entry being resized.
        cond_value: object exposing ``.cond`` and ``._copy_with``.
        window: context window; read for ``dim``, ``index_list``,
            ``modality_windows`` and the ``guide_*`` attributes.
        x_in: latent for the current window; only its sizes are read.
        device: target device for sliced tensors.
        retain_index_list: forwarded unchanged to ``window.get_tensor``.
    """
    def _holds_tensor(value):
        return hasattr(value, "cond") and isinstance(value.cond, torch.Tensor)

    # Audio denoise mask: slice with the window of modality 1 along dim 2.
    if cond_key == "audio_denoise_mask":
        modality_windows = getattr(window, 'modality_windows', None)
        if modality_windows:
            audio_window = modality_windows.get(1)
            if audio_window is not None and _holds_tensor(cond_value):
                return cond_value._copy_with(audio_window.get_tensor(cond_value.cond, device, dim=2))
        return None

    # Video denoise mask: frames beyond x_in's length are treated as guide
    # frames; the video part and the guide part are sliced independently.
    if cond_key == "denoise_mask":
        if not _holds_tensor(cond_value):
            return None
        mask = cond_value.cond
        video_len = x_in.size(window.dim)
        guide_len = mask.size(window.dim) - video_len
        if guide_len <= 0:
            return None
        video_part = mask.narrow(window.dim, 0, video_len)
        guide_part = mask.narrow(window.dim, video_len, guide_len)
        video_slice = window.get_tensor(video_part, device, retain_index_list=retain_index_list)
        guide_positions = window.guide_frames_indices
        if not guide_positions:
            return cond_value._copy_with(video_slice)
        selector = tuple([slice(None)] * window.dim + [guide_positions])
        guide_slice = guide_part[selector].to(device)
        return cond_value._copy_with(torch.cat([video_slice, guide_slice], dim=window.dim))

    # Keyframe indices: regenerate pixel coordinates for the window and keep
    # only the tokens of the guide frames located inside it.
    if cond_key == "keyframe_idxs":
        local_positions = window.guide_kf_local_positions
        if not local_positions:
            return cond_value._copy_with(cond_value.cond[:, :, :0, :])  # empty
        height, width = x_in.shape[3], x_in.shape[4]
        tokens_per_frame = height * width
        model = self.diffusion_model
        latent_coords = model.patchifier.get_latent_coords(
            len(window.index_list), height, width, 1, cond_value.cond.device)
        vae_scales = model.vae_scale_factors
        pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
            latent_coords,
            vae_scales,
            causal_fix=model.causal_temporal_positioning)
        token_ids = [
            t
            for pos in local_positions
            for t in range(pos * tokens_per_frame, (pos + 1) * tokens_per_frame)
        ]
        pixel_coords = pixel_coords[:, :, token_ids, :]

        # Dilated (downscaled) guides need their spatial end positions pushed
        # out. Factors are per guide entry, so each one is repeated for every
        # token that entry contributes to this window.
        factors = window.guide_downscale_factors
        overlaps = window.guide_overlap_info
        if factors:
            expanded = []
            for (_, overlap_count), factor in zip(overlaps, factors):
                expanded.extend([factor] * (overlap_count * tokens_per_frame))
            factor_tensor = torch.tensor(expanded, device=pixel_coords.device, dtype=pixel_coords.dtype)
            scale_tensor = torch.tensor(
                vae_scales[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
            ).view(1, -1, 1, 1)
            pixel_coords[:, 1:, :, 1:] += (factor_tensor[None, None, :, None] - 1) * scale_tensor

        batch = cond_value.cond.shape[0]
        if batch > 1:
            pixel_coords = pixel_coords.expand(batch, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)

    # Guide attention entries: rewrite each overlapping entry's token count
    # and latent shape to the number of frames that fall in this window.
    if cond_key == "guide_attention_entries":
        overlaps = window.guide_overlap_info
        height, width = x_in.shape[3], x_in.shape[4]
        return cond_value._copy_with([
            {**cond_value.cond[entry_idx],
             "pre_filter_count": overlap_count * height * width,
             "latent_shape": [overlap_count, height, width]}
            for entry_idx, overlap_count in overlaps
        ])

    return None
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
@ -135,7 +135,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
|
||||
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, latent_start=0):
|
||||
"""Append a guide_attention_entry to both positive and negative conditioning.
|
||||
|
||||
Each entry tracks one guide reference for per-reference attention control.
|
||||
@ -146,6 +146,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
|
||||
"strength": strength,
|
||||
"pixel_mask": None,
|
||||
"latent_shape": latent_shape,
|
||||
"latent_start": latent_start,
|
||||
}
|
||||
results = []
|
||||
for cond in (positive, negative):
|
||||
@ -362,6 +363,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||
positive, negative = _append_guide_attention_entry(
|
||||
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||
latent_start=latent_idx,
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user