diff --git a/comfy/context_windows.py b/comfy/context_windows.py index a4c49bbee..fe1afdffe 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -147,8 +147,7 @@ def _compute_guide_overlap(guide_entries, window_index_list): guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape') window_index_list: the window's frame indices into the video portion - Returns None if any entry lacks 'latent_start' (backward compat → legacy path). - Otherwise returns (suffix_indices, overlap_info, kf_local_positions, total_overlap): + Returns (suffix_indices, overlap_info, kf_local_positions, total_overlap): suffix_indices: indices into the guide_suffix tensor for frame selection overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment kf_local_positions: window-local frame positions for keyframe_idxs regeneration @@ -164,7 +163,7 @@ def _compute_guide_overlap(guide_entries, window_index_list): for entry_idx, entry in enumerate(guide_entries): latent_start = entry.get("latent_start", None) if latent_start is None: - return None + raise ValueError("guide_attention_entry missing required 'latent_start'.") guide_len = entry["latent_shape"][0] entry_overlap = 0 @@ -452,11 +451,7 @@ class IndexListContextHandler(ContextHandlerABC): num_guide_in_window = 0 if guide_suffix is not None and guide_entries is not None: overlap = _compute_guide_overlap(guide_entries, window.index_list) - if overlap is None: - # Legacy: no latent_start → equal-size assumption - sliced_guide = mod_windows[0].get_tensor(guide_suffix) - num_guide_in_window = sliced_guide.shape[self.dim] - elif overlap[3] > 0: + if overlap[3] > 0: suffix_idx, overlap_info, kf_local_pos, num_guide_in_window = overlap idx = tuple([slice(None)] * self.dim + [suffix_idx]) sliced_guide = guide_suffix[idx] diff --git a/comfy/model_base.py b/comfy/model_base.py index 893beb85a..e4659a236 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -305,8 +305,8 @@ class BaseModel(torch.nn.Module): def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): """Resize guide-related conditioning for context windows. - Uses overlap info from window if available (generalized path), - otherwise falls back to legacy equal-size assumption.""" + Requires guide_suffix_indices, guide_overlap_info, and guide_kf_local_positions + to be set on the window by _compute_guide_overlap.""" if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): cond_tensor = cond_value.cond guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) @@ -315,76 +315,46 @@ class BaseModel(torch.nn.Module): video_mask = cond_tensor.narrow(window.dim, 0, T_video) guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - # Use overlap-based guide selection if available, otherwise legacy - suffix_indices = getattr(window, 'guide_suffix_indices', None) - if suffix_indices is not None: + suffix_indices = window.guide_suffix_indices + if suffix_indices: idx = tuple([slice(None)] * window.dim + [suffix_indices]) - sliced_guide = guide_mask[idx].to(device) if suffix_indices else None - else: - sliced_guide = window.get_tensor(guide_mask, device) - if sliced_guide is not None and sliced_guide.shape[window.dim] > 0: + sliced_guide = guide_mask[idx].to(device) return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) else: return cond_value._copy_with(sliced_video) if cond_key == "keyframe_idxs": - kf_local_pos = getattr(window, 'guide_kf_local_positions', None) - if kf_local_pos is not None: - # Generalized: regenerate coords for full window, select guide positions - if not kf_local_pos: - return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty - H, W = x_in.shape[3], x_in.shape[4] - window_len = len(window.index_list) - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - tokens = [] - for pos in kf_local_pos: - tokens.extend(range(pos * H * W, (pos + 1) * H * W)) - pixel_coords = pixel_coords[:, :, tokens, :] - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) - else: - # Legacy: regenerate for window_len (equal-size assumption) - window_len = len(window.index_list) - H, W = x_in.shape[3], x_in.shape[4] - patchifier = self.diffusion_model.patchifier - latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) - from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords - pixel_coords = latent_to_pixel_coords( - latent_coords, - self.diffusion_model.vae_scale_factors, - causal_fix=self.diffusion_model.causal_temporal_positioning) - B = cond_value.cond.shape[0] - if B > 1: - pixel_coords = pixel_coords.expand(B, -1, -1, -1) - return cond_value._copy_with(pixel_coords) + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + pixel_coords = latent_to_pixel_coords( + latent_coords, + self.diffusion_model.vae_scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) if cond_key == "guide_attention_entries": - overlap_info = getattr(window, 'guide_overlap_info', None) - if overlap_info is not None: - # Generalized: per-guide adjustment based on overlap - H, W = x_in.shape[3], x_in.shape[4] - new_entries = [] - for entry_idx, overlap_count in overlap_info: - e = cond_value.cond[entry_idx] - new_entries.append({**e, - "pre_filter_count": overlap_count * H * W, - "latent_shape": [overlap_count, H, W]}) - return cond_value._copy_with(new_entries) - else: - # Legacy: all entries adjusted to window_len - window_len = len(window.index_list) - H, W = x_in.shape[3], x_in.shape[4] - new_entries = [{**e, "pre_filter_count": window_len * H * W, - "latent_shape": [window_len, H, W]} for e in cond_value.cond] - return cond_value._copy_with(new_entries) + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) return None