diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 91d019a00..357fbae17 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -358,11 +358,9 @@ class IndexListContextHandler(ContextHandlerABC): for window_idx, window in enumerated_context_windows: comfy.model_management.throw_exception_if_processing_interrupted() - - # Attach guide info to window for resize_cond_for_context_window - window.guide_count = guide_count - if guide_suffix is not None: - window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4]) + logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}" + + (f" (+{guide_count} guide)" if guide_count > 0 else "") + + (f" [{len(modalities)} modalities]" if is_multimodal else "")) # Per-modality window indices if is_multimodal: @@ -384,9 +382,6 @@ class IndexListContextHandler(ContextHandlerABC): window = IndexListContextWindow( window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim], modality_windows=modality_windows) - window.guide_count = guide_count - if guide_suffix is not None: - window.guide_spatial = (guide_suffix.shape[3], guide_suffix.shape[4]) else: per_mod_indices = [window.index_list] diff --git a/comfy/model_base.py b/comfy/model_base.py index 12d49305f..9c31e2651 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -303,6 +303,45 @@ class BaseModel(torch.nn.Module): Override in subclasses that concatenate guide reference frames to the latent.""" return 0 + def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + """Resize guide-related conditioning for context windows. + Derives guide_count from denoise_mask/x_in size difference. + Derives spatial dims from x_in. Requires self.diffusion_model.patchifier.""" + if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + cond_tensor = cond_value.cond + guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) + if guide_count > 0: + T_video = x_in.size(window.dim) + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + sliced_guide = window.get_tensor(guide_mask, device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + + if cond_key == "keyframe_idxs": + window_len = len(window.index_list) + H, W = x_in.shape[3], x_in.shape[4] + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + if cond_key == "guide_attention_entries": + window_len = len(window.index_list) + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [{**e, "pre_filter_count": window_len * H * W, + "latent_shape": [window_len, H, W]} for e in cond_value.cond] + return cond_value._copy_with(new_entries) + + return None + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1038,44 +1077,7 @@ class LTXV(BaseModel): return 0 def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - guide_count = getattr(window, 'guide_count', 0) - - if cond_key == "denoise_mask" and guide_count > 0: - # Slice both video and guide halves with same window indices - cond_tensor = cond_value.cond - T_video = cond_tensor.size(window.dim) - guide_count - video_mask = cond_tensor.narrow(window.dim, 0, T_video) - guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) - sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - sliced_guide = window.get_tensor(guide_mask, device) - return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) - - if cond_key == "keyframe_idxs" and guide_count > 0: - # Recompute coords for window_len frames so guide tokens are co-located - # with noise tokens in RoPE space (identical to a standalone short video) - window_len = len(window.index_list) - H, W = window.guide_spatial - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) - - if cond_key == "guide_attention_entries" and guide_count > 0: - # Adjust token counts for window size - window_len = len(window.index_list) - H, W = window.guide_spatial - new_entries = [{**e, "pre_filter_count": window_len * H * W, - "latent_shape": [window_len, H, W]} for e in cond_value.cond] - return cond_value._copy_with(new_entries) - - return None + return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list) class LTXAV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): @@ -1083,7 +1085,6 @@ class LTXAV(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) - logging.info(f"LTXAV.extra_conds: guide_attention_entries={'guide_attention_entries' in kwargs}, keyframe_idxs={'keyframe_idxs' in kwargs}") attention_mask = kwargs.get("attention_mask", None) device = kwargs["device"] @@ -1170,12 +1171,8 @@ class LTXAV(BaseModel): for cond_dict in cond_list: model_conds = cond_dict.get('model_conds', {}) gae = model_conds.get('guide_attention_entries') - logging.info(f"LTXAV.get_guide_frame_count: keys={list(model_conds.keys())}, gae={gae is not None}") if gae is not None and hasattr(gae, 'cond') and gae.cond: - count = sum(e["latent_shape"][0] for e in gae.cond) - logging.info(f"LTXAV.get_guide_frame_count: found {count} guide frames") - return count - logging.info("LTXAV.get_guide_frame_count: no guide frames found") + return sum(e["latent_shape"][0] for e in gae.cond) return 0 def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): @@ -1186,41 +1183,10 @@ class LTXAV(BaseModel): sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) return cond_value._copy_with(sliced) - # Guide handling (same as LTXV — shared guide mechanism) - guide_count = getattr(window, 'guide_count', 0) - if cond_key in ("keyframe_idxs", "guide_attention_entries", "denoise_mask"): - logging.info(f"LTXAV resize_cond: {cond_key}, guide_count={guide_count}, has_spatial={hasattr(window, 'guide_spatial')}") - - if cond_key == "denoise_mask" and guide_count > 0: - cond_tensor = cond_value.cond - T_video = cond_tensor.size(window.dim) - guide_count - video_mask = cond_tensor.narrow(window.dim, 0, T_video) - guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) - sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - sliced_guide = window.get_tensor(guide_mask, device) - return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) - - if cond_key == "keyframe_idxs" and guide_count > 0: - window_len = len(window.index_list) - H, W = window.guide_spatial - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) - - if cond_key == "guide_attention_entries" and guide_count > 0: - window_len = len(window.index_list) - H, W = window.guide_spatial - new_entries = [{**e, "pre_filter_count": window_len * H * W, - "latent_shape": [window_len, H, W]} for e in cond_value.cond] - return cond_value._copy_with(new_entries) + # Guide handling (shared with LTXV) + result = self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list) + if result is not None: + return result return None