From f1acd5bd858b732943fea784f21bbbf7afd98ce7 Mon Sep 17 00:00:00 2001 From: ozbayb <17261091+ozbayb@users.noreply.github.com> Date: Tue, 7 Apr 2026 12:43:41 -0600 Subject: [PATCH] LTX2 context windows - Cleanup: Simplify window data handling, improve variable names, refactor and condense new context window methods to separate execution paths cleanly --- comfy/context_windows.py | 242 ++++++++++++++++++++++++--------------- comfy/model_base.py | 86 ++++++-------- 2 files changed, 183 insertions(+), 145 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index e89c9cee2..2ec927f3e 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -140,15 +140,15 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) -def _compute_guide_overlap(guide_entries, window_index_list): - """Compute which guide frames overlap with a context window. +def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int]): + """Compute which concatenated guide frames overlap with a context window. Args: - guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape') + guide_entries: list of guide_attention_entry dicts window_index_list: the window's frame indices into the video portion - Returns (suffix_indices, overlap_info, kf_local_positions, total_overlap): - suffix_indices: indices into the guide_suffix tensor for frame selection + Returns: + suffix_indices: indices into the guide_frames tensor for frame selection overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment kf_local_positions: window-local frame positions for keyframe_idxs regeneration total_overlap: total number of overlapping guide frames @@ -181,11 +181,37 @@ def _compute_guide_overlap(guide_entries, window_index_list): return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) +def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC, + aux_data: dict, dim: int) -> tuple[torch.Tensor, int]: + """Inject overlapping guide frames into a context window slice. + + Uses aux_data from WindowingContext to determine which guide frames overlap + with this window's indices, concatenates them onto the video slice, and sets + window attributes for downstream conditioning resize. + + Returns (augmented_slice, num_guide_frames_added). + """ + guide_entries = aux_data["guide_entries"] + guide_frames = aux_data["guide_frames"] + overlap = compute_guide_overlap(guide_entries, window.index_list) + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + if guide_frame_count > 0: + idx = tuple([slice(None)] * dim + [suffix_idx]) + sliced_guide = guide_frames[idx] + return torch.cat([video_slice, sliced_guide], dim=dim), guide_frame_count + return video_slice, 0 + + @dataclass class WindowingContext: tensor: torch.Tensor - suffix: torch.Tensor | None + guide_frames: torch.Tensor | None aux_data: Any + latent_shapes: list | None + is_multimodal: bool @dataclass class ContextSchedule: @@ -215,8 +241,8 @@ class IndexListContextHandler(ContextHandlerABC): self.callbacks = {} - def _get_latent_shapes(self, conds): - """Extract latent_shapes from conditioning. 
Returns None if absent.""" + @staticmethod + def _get_latent_shapes(conds): for cond_list in conds: if cond_list is None: continue @@ -226,20 +252,20 @@ class IndexListContextHandler(ContextHandlerABC): return model_conds['latent_shapes'].cond return None - def _decompose(self, x, latent_shapes): - """Packed tensor -> list of per-modality tensors.""" + @staticmethod + def _unpack(combined_latent, latent_shapes): if latent_shapes is not None and len(latent_shapes) > 1: - return comfy.utils.unpack_latents(x, latent_shapes) - return [x] + return comfy.utils.unpack_latents(combined_latent, latent_shapes) + return [combined_latent] - def _compose(self, modalities): - """List of per-modality tensors -> single tensor for pipeline.""" - if len(modalities) > 1: - return comfy.utils.pack_latents(modalities) - return modalities[0], [modalities[0].shape] + @staticmethod + def _pack(latents): + if len(latents) > 1: + return comfy.utils.pack_latents(latents) + return latents[0], [latents[0].shape] - def _patch_latent_shapes(self, sub_conds, new_shapes): - """Patch latent_shapes CONDConstant in (already-copied) sub_conds.""" + @staticmethod + def _patch_latent_shapes(sub_conds, new_shapes): for cond_list in sub_conds: if cond_list is None: continue @@ -248,14 +274,48 @@ class IndexListContextHandler(ContextHandlerABC): if 'latent_shapes' in model_conds: model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) - def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingContext: latent_shapes = self._get_latent_shapes(conds) - primary = self._decompose(x_in, latent_shapes)[0] - guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0 - video_frames = primary.size(self.dim) - guide_count + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + if is_multimodal: + video_latent = comfy.utils.unpack_latents(x_in, latent_shapes)[0] + else: + video_latent = x_in + + guide_entries = None + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + guide_entries = entries.cond + break + if guide_entries is not None: + break + + guide_frame_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries is not None else 0 + primary_frame_count = video_latent.size(self.dim) - guide_frame_count + primary_frames = video_latent.narrow(self.dim, 0, primary_frame_count) + guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None + + if guide_frame_count > 0: + aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames} + else: + aux_data = None + + return WindowingContext( + tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data, + latent_shapes=latent_shapes, is_multimodal=is_multimodal) + + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + self._window_data = self._build_window_data(x_in, conds) + video_frames = self._window_data.tensor.size(self.dim) + guide_frames = self._window_data.guide_frames.size(self.dim) if self._window_data.guide_frames is not None else 
0 if video_frames > self.context_length: - if guide_count > 0: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_count} guide frames excluded).") + if guide_frames > 0: + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_frames} guide frames).") else: logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.") if self.cond_retain_index_list: @@ -367,15 +427,9 @@ class IndexListContextHandler(ContextHandlerABC): self._model = model self.set_step(timestep, model_options) - # Check if multimodal or model has auxiliary frames requiring the extended path - latent_shapes = self._get_latent_shapes(conds) - is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 - if is_multimodal: - return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, latent_shapes) - window_data = model.prepare_for_windowing(x_in, conds, self.dim) - if window_data.suffix is not None or window_data.aux_data is not None: - return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, - latent_shapes, window_data) + window_data = self._window_data + if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0): + return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data) context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) @@ -410,101 +464,104 @@ class IndexListContextHandler(ContextHandlerABC): def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str], - latent_shapes, window_data: WindowingContext=None): - """Extended execute path for multimodal models and models with auxiliary frames.""" - modalities = self._decompose(x_in, latent_shapes) - is_multimodal = len(modalities) > 1 + window_data: WindowingContext): + """Extended execute path for multimodal models and models with guide frames appended to the noise latent.""" + latents = self._unpack(x_in, window_data.latent_shapes) + is_multimodal = window_data.is_multimodal - if window_data is None: - window_data = model.prepare_for_windowing(modalities[0], conds, self.dim) + primary_frames = window_data.tensor + num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0 - video_primary = window_data.tensor - aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0 - - context_windows = self.get_context_windows(model, video_primary, model_options) + context_windows = self.get_context_windows(model, primary_frames, model_options) enumerated_context_windows = list(enumerate(context_windows)) total_windows = len(enumerated_context_windows) # Accumulators sized to video portion for primary, full for other modalities - accum_modalities = list(modalities) - if window_data.suffix is not None: - accum_modalities[0] = video_primary + accum_shape_refs = list(latents) + if window_data.guide_frames is not None: + accum_shape_refs[0] = primary_frames - accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities] + accum = [[torch.zeros_like(m) for _ in conds] for m in accum_shape_refs] if self.fuse_method.name 
== ContextFuseMethods.RELATIVE: - counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] else: - counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities] - biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_shape_refs] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) for window_idx, window in enumerated_context_windows: comfy.model_management.throw_exception_if_processing_interrupted() - logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}" - + (f" (+{aux_count} aux)" if aux_count > 0 else "") - + (f" [{len(modalities)} modalities]" if is_multimodal else "")) + logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}" + + (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "") + + (f" [{len(latents)} modalities]" if is_multimodal else "")) # Per-modality window indices if is_multimodal: - map_shapes = latent_shapes - if video_primary.size(self.dim) != modalities[0].size(self.dim): - map_shapes = list(latent_shapes) - video_shape = list(latent_shapes[0]) - video_shape[self.dim] = video_primary.size(self.dim) + map_shapes = window_data.latent_shapes + if primary_frames.size(self.dim) != latents[0].size(self.dim): + map_shapes = list(window_data.latent_shapes) + video_shape = list(window_data.latent_shapes[0]) + video_shape[self.dim] = primary_frames.size(self.dim) map_shapes[0] = torch.Size(video_shape) - per_mod_indices = model.map_context_window_to_modalities( - window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list] + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") modality_windows = {} - for mod_idx in range(1, len(modalities)): + for mod_idx in range(1, len(latents)): modality_windows[mod_idx] = IndexListContextWindow( - per_mod_indices[mod_idx], dim=self.dim, - total_frames=modalities[mod_idx].shape[self.dim]) + per_modality_indices[mod_idx], dim=self.dim, + total_frames=latents[mod_idx].shape[self.dim]) window = IndexListContextWindow( - window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim], + window.index_list, dim=self.dim, total_frames=primary_frames.shape[self.dim], modality_windows=modality_windows) # Build per-modality windows list - mod_windows = [window] + per_modality_windows_list = [window] if is_multimodal: - for mod_idx in range(1, len(modalities)): - mod_windows.append(modality_windows[mod_idx]) + for mod_idx in range(1, len(latents)): + per_modality_windows_list.append(modality_windows[mod_idx]) - # Slice video, then let model inject auxiliary frames - sliced_video = 
mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list) - sliced_primary, num_aux = model.prepare_window_input( - sliced_video, window, window_data.aux_data, self.dim) - sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))] + # Slice video, then inject overlapping guide frames if present + sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list) + if window_data.aux_data is not None: + sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim) + else: + sliced_primary, num_guide_frames = sliced_video, 0 + sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))] - sub_x, sub_shapes = self._compose(sliced) + sub_x, sub_shapes = self._pack(sliced) for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None) model_options["transformer_options"]["context_window"] = window sub_timestep = window.get_tensor(timestep, dim=0) - sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds] + sub_conds = [self.get_resized_cond(cond, primary_frames, window) for cond in conds] if is_multimodal: self._patch_latent_shapes(sub_conds, sub_shapes) sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - # Decompose output per modality - out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + # Unpack output per modality + out_per_modality = [self._unpack(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] - # Strip auxiliary frames from primary output before accumulation - if num_aux > 0: + # Strip guide frames from primary output before accumulation + if num_guide_frames > 0: window_len = len(window.index_list) for ci in range(len(sub_conds_out)): - out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len) + out_per_modality[ci][0] = out_per_modality[ci][0].narrow(self.dim, 0, window_len) # Accumulate per modality - for mod_idx in range(len(accum_modalities)): - mw = mod_windows[mod_idx] - mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))] + for mod_idx in range(len(accum_shape_refs)): + mw = per_modality_windows_list[mod_idx] + sub_conds_out_per_modality = [out_per_modality[ci][mod_idx] for ci in range(len(sub_conds_out))] self.combine_context_window_results( - accum_modalities[mod_idx], mod_sub_out, sub_conds, mw, + accum_shape_refs[mod_idx], sub_conds_out_per_modality, sub_conds, mw, window_idx, total_windows, timestep, accum[mod_idx], counts[mod_idx], biases[mod_idx]) @@ -512,15 +569,15 @@ class IndexListContextHandler(ContextHandlerABC): result = [] for ci in range(len(conds)): finalized = [] - for mod_idx in range(len(accum_modalities)): + for mod_idx in range(len(accum_shape_refs)): if self.fuse_method.name != ContextFuseMethods.RELATIVE: accum[mod_idx][ci] /= counts[mod_idx][ci] f = accum[mod_idx][ci] - if mod_idx == 0 and window_data.suffix is not None: - f = torch.cat([f, window_data.suffix], dim=self.dim) + if mod_idx == 0 and window_data.guide_frames is not None: + f = torch.cat([f, window_data.guide_frames], dim=self.dim) finalized.append(f) - composed, _ = self._compose(finalized) - result.append(composed) + packed, _ = 
self._pack(finalized) + result.append(packed) return result finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): @@ -616,11 +673,8 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois # For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise # must only shuffle the video portion. Unpack, apply to video, repack. - latent_shapes = None - try: - latent_shapes = guider.conds['positive'][0]['model_conds']['latent_shapes'].cond - except (KeyError, IndexError, AttributeError): - pass + latent_shapes = IndexListContextHandler._get_latent_shapes( + [guider.conds.get('positive', guider.conds.get('negative', []))]) if latent_shapes is not None and len(latent_shapes) > 1: modalities = comfy.utils.unpack_latents(noise, latent_shapes) diff --git a/comfy/model_base.py b/comfy/model_base.py index c1e2ef8c2..65ce1bac5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -287,12 +287,6 @@ class BaseModel(torch.nn.Module): return data return None - def prepare_for_windowing(self, primary, conds, dim): - return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) - - def prepare_window_input(self, video_slice, window, aux_data, dim): - return video_slice, 0 - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): """Override in subclasses to handle model-specific cond slicing for context windows. Return a sliced cond object, or None to fall through to default handling. @@ -1098,7 +1092,7 @@ class LTXAV(BaseModel): for i in range(1, len(latent_shapes)): mod_total = latent_shapes[i][dim] - # Length proportional to video window frame count (not index span) + # Length proportional to video window frame count mod_window_len = max(round(video_window_len * mod_total / video_total), 1) # Anchor to end of video range v_end = max(primary_indices) + 1 @@ -1108,17 +1102,6 @@ class LTXAV(BaseModel): return result - def get_guide_frame_count(self, x, conds): - for cond_list in conds: - if cond_list is None: - continue - for cond_dict in cond_list: - model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return sum(e["latent_shape"][0] for e in gae.cond) - return 0 - @staticmethod def _get_guide_entries(conds): for cond_list in conds: @@ -1126,43 +1109,27 @@ class LTXAV(BaseModel): continue for cond_dict in cond_list: model_conds = cond_dict.get('model_conds', {}) - gae = model_conds.get('guide_attention_entries') - if gae is not None and hasattr(gae, 'cond') and gae.cond: - return gae.cond + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond return None - - def prepare_for_windowing(self, primary, conds, dim): - guide_count = self.get_guide_frame_count(primary, conds) + + def prepare_window_data(self, x_in, conds, dim, window_data): + primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 if guide_count <= 0: - return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None) + return comfy.context_windows.WindowingContext( + tensor=primary, guide_frames=None, aux_data=None, + 
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal) video_len = primary.size(dim) - guide_count video_primary = primary.narrow(dim, 0, video_len) - guide_suffix = primary.narrow(dim, video_len, guide_count) - guide_entries = self._get_guide_entries(conds) + guide_frames = primary.narrow(dim, video_len, guide_count) return comfy.context_windows.WindowingContext( - tensor=video_primary, suffix=guide_suffix, - aux_data={"guide_entries": guide_entries, "guide_suffix": guide_suffix}) + tensor=video_primary, guide_frames=guide_frames, + aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames}, + latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal) - def prepare_window_input(self, video_slice, window, aux_data, dim): - if aux_data is None: - return video_slice, 0 - guide_entries = aux_data["guide_entries"] - guide_suffix = aux_data["guide_suffix"] - if guide_entries is None: - window.guide_suffix_indices = [] - window.guide_overlap_info = [] - window.guide_kf_local_positions = [] - return video_slice, 0 - overlap = comfy.context_windows._compute_guide_overlap(guide_entries, window.index_list) - suffix_idx, overlap_info, kf_local_pos, num_guide = overlap - window.guide_suffix_indices = suffix_idx - window.guide_overlap_info = overlap_info - window.guide_kf_local_positions = kf_local_pos - if num_guide > 0: - idx = tuple([slice(None)] * dim + [suffix_idx]) - sliced_guide = guide_suffix[idx] - return torch.cat([video_slice, sliced_guide], dim=dim), num_guide - return video_slice, 0 def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): # Audio denoise mask — slice using audio modality window @@ -1181,7 +1148,7 @@ class LTXAV(BaseModel): video_mask = cond_tensor.narrow(window.dim, 0, T_video) guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) - suffix_indices = window.guide_suffix_indices + suffix_indices = window.guide_frames_indices if suffix_indices: idx = tuple([slice(None)] * window.dim + [suffix_indices]) sliced_guide = guide_mask[idx].to(device) @@ -1199,14 +1166,31 @@ class LTXAV(BaseModel): patchifier = self.diffusion_model.patchifier latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords + scale_factors = self.diffusion_model.vae_scale_factors pixel_coords = latent_to_pixel_coords( latent_coords, - self.diffusion_model.vae_scale_factors, + scale_factors, causal_fix=self.diffusion_model.causal_temporal_positioning) tokens = [] for pos in kf_local_pos: tokens.extend(range(pos * H * W, (pos + 1) * H * W)) pixel_coords = pixel_coords[:, :, tokens, :] + + # Adjust spatial end positions for dilated (downscaled) guides. + # Each guide entry may have a different downscale factor; expand the + # per-entry factor to cover all tokens belonging to that entry. 
+ downscale_factors = getattr(window, 'guide_downscale_factors', []) + overlap_info = window.guide_overlap_info + if downscale_factors: + per_token_factor = [] + for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors): + per_token_factor.extend([dsf] * (overlap_count * H * W)) + factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype) + spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor( + scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset + B = cond_value.cond.shape[0] if B > 1: pixel_coords = pixel_coords.expand(B, -1, -1, -1)
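
Note on the guide-frame bookkeeping (illustrative appendix, not part of the patch): the body of compute_guide_overlap() is unchanged by this commit, so the sketch below is written only against its documented contract, i.e. each guide_attention_entry carries 'latent_start' and 'latent_shape', guide frames are concatenated after the video frames along the frame dim, and the function returns (suffix_indices, overlap_info, kf_local_positions, total_overlap). The helper name sketch_guide_overlap and the toy shapes are assumptions for illustration only.

# Minimal sketch, assuming the contract stated in the compute_guide_overlap()
# docstring; the real implementation may differ in detail.
def sketch_guide_overlap(guide_entries, window_index_list):
    window_set = set(window_index_list)
    suffix_indices = []        # indices into the concatenated guide_frames tensor
    overlap_info = []          # (entry_idx, overlap_count) per overlapping entry
    kf_local_positions = []    # window-local positions of the overlapping frames
    suffix_offset = 0          # running offset of each entry inside guide_frames
    for entry_idx, entry in enumerate(guide_entries):
        start = entry["latent_start"]        # frame index into the video portion
        length = entry["latent_shape"][0]    # only the frame count is used here
        overlap = 0
        for frame in range(start, start + length):
            if frame in window_set:
                suffix_indices.append(suffix_offset + (frame - start))
                kf_local_positions.append(window_index_list.index(frame))
                overlap += 1
        if overlap:
            overlap_info.append((entry_idx, overlap))
        suffix_offset += length
    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)

# Example: a 16-frame window covering video frames 8..23 overlaps only the
# second guide entry (frames 10-11), so inject_guide_frames_into_window()
# would concatenate guide_frames[1:3] onto that window's video slice.
entries = [
    {"latent_start": 0,  "latent_shape": (1, 64, 64)},
    {"latent_start": 10, "latent_shape": (2, 64, 64)},
]
window = list(range(8, 24))
print(sketch_guide_overlap(entries, window))
# -> ([1, 2], [(1, 2)], [2, 3], 2)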