diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 964f7dd33..a9f456426 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Callable import torch import numpy as np import collections @@ -60,6 +60,10 @@ class IndexListContextWindow(ContextWindowABC): self.total_frames = total_frames self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow} + self.guide_frames_indices: list[int] = [] + self.guide_overlap_info: list[tuple[int, int]] = [] + self.guide_kf_local_positions: list[int] = [] + self.guide_downscale_factors: list[int] = [] def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: @@ -84,6 +88,11 @@ class IndexListContextWindow(ContextWindowABC): region_idx = int(self.center_ratio * num_regions) return min(max(region_idx, 0), num_regions - 1) + def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow': + if modality_idx == 0: + return self + return self.modality_windows[modality_idx] + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -181,46 +190,109 @@ def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) -def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC, - window_data: 'WindowingContext', dim: int) -> tuple[torch.Tensor, int]: - """Inject overlapping guide frames into a context window slice. - - Determines which guide frames overlap with this window's indices, concatenates - them onto the video slice, and sets window attributes for downstream conditioning resize. - - Returns (augmented_slice, num_guide_frames_added). - """ - guide_entries = window_data.aux_data["guide_entries"] - guide_frames = window_data.guide_frames - overlap = compute_guide_overlap(guide_entries, window.index_list) - suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap - window.guide_frames_indices = suffix_idx - window.guide_overlap_info = overlap_info - window.guide_kf_local_positions = kf_local_pos - - # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. - # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. - guide_downscale_factors = [] - if guide_frame_count > 0: - full_H = guide_frames.shape[3] - for entry_idx, _ in overlap_info: - entry_H = guide_entries[entry_idx]["latent_shape"][1] - guide_downscale_factors.append(full_H // entry_H) - window.guide_downscale_factors = guide_downscale_factors - - if guide_frame_count > 0: - idx = tuple([slice(None)] * dim + [suffix_idx]) - sliced_guide = guide_frames[idx] - return torch.cat([video_slice, sliced_guide], dim=dim), guide_frame_count - return video_slice, 0 - - @dataclass -class WindowingContext: - tensor: torch.Tensor - guide_frames: torch.Tensor | None - aux_data: Any - latent_shapes: list | None +class WindowingState: + """Per-modality context windowing state for each step, + built using IndexListContextHandler._build_window_state(). + For non-multimodal models the lists are length 1 + """ + latents: list[torch.Tensor] # per-modality working latents (guide frames stripped) + guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents + guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata + latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal) + dim: int = 0 # primary modality temporal dim for context windowing + is_multimodal: bool = False + + def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow: + """Reformat window for multimodal contexts by deriving per-modality index lists. + Non-multimodal contexts return the input window unchanged. + """ + if not self.is_multimodal: + return window + + x = self.latents[0] + map_shapes = self.latent_shapes + if x.size(self.dim) != self.latent_shapes[0][self.dim]: + map_shapes = list(self.latent_shapes) + video_shape = list(self.latent_shapes[0]) + video_shape[self.dim] = x.size(self.dim) + map_shapes[0] = torch.Size(video_shape) + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") + modality_windows = {} + for mod_idx in range(1, len(self.latents)): + modality_windows[mod_idx] = IndexListContextWindow( + per_modality_indices[mod_idx], dim=self.dim, + total_frames=self.latents[mod_idx].shape[self.dim]) + return IndexListContextWindow( + window.index_list, dim=self.dim, total_frames=x.shape[self.dim], + modality_windows=modality_windows) + + def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]: + """Slice latents for a context window, injecting guide frames where applicable. + For multimodal contexts, uses the modality-specific windows derived in prepare_window(). + """ + sliced = [] + guide_frame_counts = [] + for idx in range(len(self.latents)): + modality_window = window.get_window_for_modality(idx) + retain = retain_index_list if idx == 0 else [] + s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain) + if self.guide_entries[idx] is not None: + s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx) + else: + ng = 0 + sliced.append(s) + guide_frame_counts.append(ng) + return sliced, guide_frame_counts + + def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow): + """Strip injected guide frames from per-cond, per-modality outputs in place.""" + for idx in range(len(self.latents)): + if guide_frame_counts[idx] > 0: + window_len = len(window.get_window_for_modality(idx).index_list) + for ci in range(len(out_per_modality)): + out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len) + + def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]: + guide_entries = self.guide_entries[modality_idx] + guide_frames = self.guide_latents[modality_idx] + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(guide_entries, window.index_list) + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. + # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. + guide_downscale_factors = [] + if guide_frame_count > 0: + full_H = guide_frames.shape[3] + for entry_idx, _ in overlap_info: + entry_H = guide_entries[entry_idx]["latent_shape"][1] + guide_downscale_factors.append(full_H // entry_H) + window.guide_downscale_factors = guide_downscale_factors + + if guide_frame_count > 0: + idx = tuple([slice(None)] * self.dim + [suffix_idx]) + return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count + return latent_slice, 0 + + def patch_latent_shapes(self, sub_conds, new_shapes): + if not self.is_multimodal: + return + + for cond_list in sub_conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) + @dataclass class ContextSchedule: @@ -261,37 +333,35 @@ class IndexListContextHandler(ContextHandlerABC): return model_conds['latent_shapes'].cond return None - @staticmethod - def _unpack(combined_latent, latent_shapes): + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.""" + latent_shapes = self._get_latent_shapes(conds) if latent_shapes is not None and len(latent_shapes) > 1: - return comfy.utils.unpack_latents(combined_latent, latent_shapes) - return [combined_latent] + modalities = comfy.utils.unpack_latents(noise, latent_shapes) + primary_total = latent_shapes[0][self.dim] + modalities[0] = apply_freenoise(modalities[0], self.dim, self.context_length, self.context_overlap, seed) + for i in range(1, len(modalities)): + mod_total = latent_shapes[i][self.dim] + ratio = mod_total / primary_total if primary_total > 0 else 1 + mod_ctx_len = max(round(self.context_length * ratio), 1) + mod_ctx_overlap = max(round(self.context_overlap * ratio), 0) + modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) + noise, _ = comfy.utils.pack_latents(modalities) + return noise + return apply_freenoise(noise, self.dim, self.context_length, self.context_overlap, seed) - @staticmethod - def _pack(latents): - if len(latents) > 1: - return comfy.utils.pack_latents(latents) - return latents[0], [latents[0].shape] - - @staticmethod - def _patch_latent_shapes(sub_conds, new_shapes): - for cond_list in sub_conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - if 'latent_shapes' in model_conds: - model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) - - def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingContext: + def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingState: + """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" latent_shapes = self._get_latent_shapes(conds) is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 - if is_multimodal: - video_latent = comfy.utils.unpack_latents(x_in, latent_shapes)[0] - else: - video_latent = x_in + unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in] - guide_entries = None + unpacked_latents_list = list(unpacked_latents) + guide_latents_list = [None] * len(unpacked_latents) + guide_entries_list = [None] * len(unpacked_latents) + + # Scan for 'guide_attention_entries' in conds + extracted_guide_entries = None for cond_list in conds: if cond_list is None: continue @@ -299,37 +369,39 @@ class IndexListContextHandler(ContextHandlerABC): model_conds = cond_dict.get('model_conds', {}) entries = model_conds.get('guide_attention_entries') if entries is not None and hasattr(entries, 'cond') and entries.cond: - guide_entries = entries.cond + extracted_guide_entries = entries.cond break - if guide_entries is not None: + if extracted_guide_entries is not None: break - guide_frame_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries is not None else 0 - primary_frame_count = video_latent.size(self.dim) - guide_frame_count - primary_frames = video_latent.narrow(self.dim, 0, primary_frame_count) - guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None + # Strip guide frames (only from first modality for now) + if extracted_guide_entries is not None: + guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries) + if guide_count > 0: + x = unpacked_latents[0] + latent_count = x.size(self.dim) - guide_count + unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count) + guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count) + guide_entries_list[0] = extracted_guide_entries - if guide_frame_count > 0: - aux_data = {"guide_entries": guide_entries} - else: - aux_data = None - return WindowingContext( - tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data, - latent_shapes=latent_shapes) + return WindowingState( + latents=unpacked_latents_list, + guide_latents=guide_latents_list, + guide_entries=guide_entries_list, + latent_shapes=latent_shapes, + dim=self.dim, + is_multimodal=is_multimodal) def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: - self._window_data = self._build_window_data(x_in, conds) - video_frames = self._window_data.tensor.size(self.dim) - guide_frames = self._window_data.guide_frames.size(self.dim) if self._window_data.guide_frames is not None else 0 - if video_frames > self.context_length: - if guide_frames > 0: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_frames} guide frames).") - else: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.") + window_state = self._build_window_state(x_in, conds) # build window_state to check frame counts, will be built again in execute + total_frame_count = window_state.latents[0].size(self.dim) + if total_frame_count > self.context_length: + logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") if self.cond_retain_index_list: logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") return True + logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).") return False def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: @@ -436,188 +508,121 @@ class IndexListContextHandler(ContextHandlerABC): self._model = model self.set_step(timestep, model_options) - window_data = self._window_data - is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1 - has_guide_frames = window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0 + window_state = self._build_window_state(x_in, conds) + num_modalities = len(window_state.latents) - # if multimodal or has concatenated guide frames on noise latent, use the extended execute path - if is_multimodal or has_guide_frames: - return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data) - - # basic legacy execution path for single-modal video latent with no guide frames concatenated - context_windows = self.get_context_windows(model, x_in, model_options) - enumerated_context_windows = list(enumerate(context_windows)) - - conds_final = [torch.zeros_like(x_in) for _ in conds] - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - else: - counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] - - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): - callback(self, model, x_in, conds, timestep, model_options) - - for enum_window in enumerated_context_windows: - results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) - for result in results: - self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, - conds_final, counts_final, biases_final) - try: - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - del counts_final - return conds_final - else: - for i in range(len(conds_final)): - conds_final[i] /= counts_final[i] - del counts_final - return conds_final - finally: - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): - callback(self, model, x_in, conds, timestep, model_options) - - def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, - timestep: torch.Tensor, model_options: dict[str], - window_data: WindowingContext): - """Extended execute path for multimodal models and models with guide frames appended to the noise latent.""" - - latents = self._unpack(x_in, window_data.latent_shapes) - is_multimodal = window_data.latent_shapes is not None and len(window_data.latent_shapes) > 1 - - primary_frames = window_data.tensor - context_windows = self.get_context_windows(model, primary_frames, model_options) + context_windows = self.get_context_windows(model, window_state.latents[0], model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) - # Accumulators sized to video portion for primary, full for other modalities - accum_shape_refs = list(latents) - if window_data.guide_frames is not None: - accum_shape_refs[0] = primary_frames - - accum = [[torch.zeros_like(m) for _ in conds] for m in accum_shape_refs] + # Initialize per-modality accumulators (length 1 for single-modality) + accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents] if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] else: - counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] - biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_shape_refs] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - for window_idx, window in enumerated_context_windows: - comfy.model_management.throw_exception_if_processing_interrupted() - - # Per-modality window indices - if is_multimodal: - map_shapes = window_data.latent_shapes - if primary_frames.size(self.dim) != latents[0].size(self.dim): - map_shapes = list(window_data.latent_shapes) - video_shape = list(window_data.latent_shapes[0]) - video_shape[self.dim] = primary_frames.size(self.dim) - map_shapes[0] = torch.Size(video_shape) - try: - per_modality_indices = model.map_context_window_to_modalities( - window.index_list, map_shapes, self.dim) - except AttributeError: - raise NotImplementedError( - f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") - modality_windows = {} - for mod_idx in range(1, len(latents)): - modality_windows[mod_idx] = IndexListContextWindow( - per_modality_indices[mod_idx], dim=self.dim, - total_frames=latents[mod_idx].shape[self.dim]) - window = IndexListContextWindow( - window.index_list, dim=self.dim, total_frames=primary_frames.shape[self.dim], - modality_windows=modality_windows) - - # Build per-modality windows list - per_modality_windows_list = [window] - if is_multimodal: - for mod_idx in range(1, len(latents)): - per_modality_windows_list.append(modality_windows[mod_idx]) - - # Slice video, then inject overlapping guide frames if present - sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list) - if window_data.aux_data is not None: - sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data, self.dim) - else: - sliced_primary, num_guide_frames = sliced_video, 0 - logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}" - + (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "") - + (f" [{len(latents)} modalities]" if is_multimodal else "")) - sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))] - - sub_x, sub_shapes = self._pack(sliced) - - for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): - callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None) - - model_options["transformer_options"]["context_window"] = window - sub_timestep = window.get_tensor(timestep, dim=0) - sub_conds = [self.get_resized_cond(cond, primary_frames, window) for cond in conds] - if is_multimodal: - self._patch_latent_shapes(sub_conds, sub_shapes) - - sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - - # Unpack output per modality - out_per_modality = [self._unpack(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] - - # Strip guide frames from primary output before accumulation - if num_guide_frames > 0: - window_len = len(window.index_list) - for ci in range(len(sub_conds_out)): - out_per_modality[ci][0] = out_per_modality[ci][0].narrow(self.dim, 0, window_len) - - # Accumulate per modality - for mod_idx in range(len(accum_shape_refs)): - mw = per_modality_windows_list[mod_idx] - sub_conds_out_per_modality = [out_per_modality[ci][mod_idx] for ci in range(len(sub_conds_out))] - self.combine_context_window_results( - accum_shape_refs[mod_idx], sub_conds_out_per_modality, sub_conds, mw, - window_idx, total_windows, timestep, - accum[mod_idx], counts[mod_idx], biases[mod_idx]) + # accumulate results from each context window + for enum_window in enumerated_context_windows: + results = self.evaluate_context_windows( + calc_cond_batch, model, x_in, conds, timestep, [enum_window], + model_options, window_state=window_state, total_windows=total_windows) + for result in results: + # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]] + for mod_idx in range(num_modalities): + mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))] + modality_window = result.window.get_window_for_modality(mod_idx) + self.combine_context_window_results( + window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window, + result.window_idx, total_windows, timestep, + accum[mod_idx], counts[mod_idx], biases[mod_idx]) + # fuse accumulated results into final conds try: - result = [] + result_out = [] for ci in range(len(conds)): finalized = [] - for mod_idx in range(len(accum_shape_refs)): + for mod_idx in range(num_modalities): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - if mod_idx == 0 and window_data.guide_frames is not None: - f = torch.cat([f, window_data.guide_frames], dim=self.dim) + + # if guide frames were injected, append them to the end of the fused latents for the next step + if window_state.guide_latents[mod_idx] is not None: + f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim) finalized.append(f) - packed, _ = self._pack(finalized) - result.append(packed) - return result + + # pack modalities together if needed + if window_state.is_multimodal and len(finalized) > 1: + packed, _ = comfy.utils.pack_latents(finalized) + else: + packed = finalized[0] + + result_out.append(packed) + return result_out finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], - model_options, device=None, first_device=None): + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, + timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, window_state: WindowingState, total_windows: int = None, + device=None, first_device=None): + """Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out + + For each window: + 1. Builds windows (for each modality if multimodal) + 2. Slices window for each modality + 3. Injects concatenated latent guide frames where present + 4. Packs together if needed and calls model + 5. Unpacks and strips any guides from outputs + """ + x = window_state.latents[0] + results: list[ContextResults] = [] for window_idx, window in enumerated_context_windows: # allow processing to end between context window executions for faster Cancel comfy.model_management.throw_exception_if_processing_interrupted() + # prepare the window accounting for multimodal windows + window = window_state.prepare_window(window, model) + + # slice the window for each modality, injecting guide frames where applicable + sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.cond_retain_index_list, device) + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) - # update exposed params - model_options["transformer_options"]["context_window"] = window - # get subsections of x, timestep, conds - sub_x = window.get_tensor(x_in, device) - sub_timestep = window.get_tensor(timestep, device, dim=0) - sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}" + + (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "") + ) + # if multimodal, pack modalities together + if window_state.is_multimodal and len(sliced) > 1: + sub_x, sub_shapes = comfy.utils.pack_latents(sliced) + else: + sub_x, sub_shapes = sliced[0], [sliced[0].shape] + + # get resized conds for window + model_options["transformer_options"]["context_window"] = window + sub_timestep = window.get_tensor(timestep, dim=0) + sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds] + + # if multimodal, patch latent_shapes in conds for correct unpacking in model + window_state.patch_latent_shapes(sub_conds, sub_shapes) + + # call model on window sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - if device is not None: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].to(x_in.device) - results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + + # unpack outputs and strip guide frames + out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window) + + results.append(ContextResults(window_idx, out_per_modality, sub_conds, window)) return results @@ -684,28 +689,11 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - # For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise - # must only shuffle the video portion. Unpack, apply to video, repack. - latent_shapes = IndexListContextHandler._get_latent_shapes( - [guider.conds.get('positive', guider.conds.get('negative', []))]) - - if latent_shapes is not None and len(latent_shapes) > 1: - modalities = comfy.utils.unpack_latents(noise, latent_shapes) - video_total = latent_shapes[0][handler.dim] - modalities[0] = apply_freenoise(modalities[0], handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) - for i in range(1, len(modalities)): - mod_total = latent_shapes[i][handler.dim] - ratio = mod_total / video_total if video_total > 0 else 1 - mod_ctx_len = max(round(handler.context_length * ratio), 1) - mod_ctx_overlap = max(round(handler.context_overlap * ratio), 0) - modalities[i] = apply_freenoise(modalities[i], handler.dim, mod_ctx_len, mod_ctx_overlap, extra_args["seed"]) - noise, _ = comfy.utils.pack_latents(modalities) - else: - noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + conds = [guider.conds.get('positive', guider.conds.get('negative', []))] + noise = handler._apply_freenoise(noise, conds, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - def create_sampler_sample_wrapper(model: ModelPatcher): model.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, @@ -713,7 +701,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher): _sampler_sample_wrapper ) - def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8960ecd19..08f8b058d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1162,7 +1162,7 @@ class LTXAV(BaseModel): # Adjust spatial end positions for dilated (downscaled) guides. # Each guide entry may have a different downscale factor; expand the # per-entry factor to cover all tokens belonging to that entry. - downscale_factors = getattr(window, 'guide_downscale_factors', []) + downscale_factors = window.guide_downscale_factors overlap_info = window.guide_overlap_info if downscale_factors: per_token_factor = []