LTX2 context windows - Cleanup: simplify window data handling, improve variable names, and refactor and condense the new context window methods so the execution paths separate cleanly

ozbayb 2026-04-07 12:43:41 -06:00
parent 3a061f4bbf
commit f1acd5bd85
2 changed files with 183 additions and 145 deletions

View File

@@ -140,15 +140,15 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
return cond_value._copy_with(sliced)
def _compute_guide_overlap(guide_entries, window_index_list):
"""Compute which guide frames overlap with a context window.
def compute_guide_overlap(guide_entries: list[dict], window_index_list: list[int]):
"""Compute which concatenated guide frames overlap with a context window.
Args:
guide_entries: list of guide_attention_entry dicts (must have 'latent_start' and 'latent_shape')
guide_entries: list of guide_attention_entry dicts
window_index_list: the window's frame indices into the video portion
Returns (suffix_indices, overlap_info, kf_local_positions, total_overlap):
suffix_indices: indices into the guide_suffix tensor for frame selection
Returns:
suffix_indices: indices into the guide_frames tensor for frame selection
overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
kf_local_positions: window-local frame positions for keyframe_idxs regeneration
total_overlap: total number of overlapping guide frames
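The hunk shows only the docstring; as a rough illustration of the documented contract, here is a toy sketch under assumed semantics (each entry occupies frames [latent_start, latent_start + latent_shape[0]) of the video timeline, and entries' frames are stacked in order inside the guide tensor). It is not the real implementation:

    def toy_compute_guide_overlap(guide_entries, window_index_list):
        # Assumed semantics, not the function under review.
        window_set = set(window_index_list)
        suffix_indices, overlap_info, kf_local_positions = [], [], []
        suffix_base = 0  # offset of this entry's frames inside the guide tensor
        for entry_idx, entry in enumerate(guide_entries):
            start, count = entry["latent_start"], entry["latent_shape"][0]
            overlap_count = 0
            for offset in range(count):
                frame = start + offset
                if frame in window_set:
                    suffix_indices.append(suffix_base + offset)
                    kf_local_positions.append(window_index_list.index(frame))
                    overlap_count += 1
            if overlap_count:
                overlap_info.append((entry_idx, overlap_count))
            suffix_base += count
        return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)

    entries = [{"latent_start": 6, "latent_shape": (2, 32, 32)}]
    print(toy_compute_guide_overlap(entries, list(range(4, 12))))
    # ([0, 1], [(0, 2)], [2, 3], 2)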
@@ -181,11 +181,37 @@ def _compute_guide_overlap(guide_entries, window_index_list):
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
def inject_guide_frames_into_window(video_slice: torch.Tensor, window: ContextWindowABC,
aux_data: dict, dim: int) -> tuple[torch.Tensor, int]:
"""Inject overlapping guide frames into a context window slice.
Uses aux_data from WindowingContext to determine which guide frames overlap
with this window's indices, concatenates them onto the video slice, and sets
window attributes for downstream conditioning resize.
Returns (augmented_slice, num_guide_frames_added).
"""
guide_entries = aux_data["guide_entries"]
guide_frames = aux_data["guide_frames"]
overlap = compute_guide_overlap(guide_entries, window.index_list)
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = overlap
window.guide_frames_indices = suffix_idx
window.guide_overlap_info = overlap_info
window.guide_kf_local_positions = kf_local_pos
if guide_frame_count > 0:
idx = tuple([slice(None)] * dim + [suffix_idx])
sliced_guide = guide_frames[idx]
return torch.cat([video_slice, sliced_guide], dim=dim), guide_frame_count
return video_slice, 0
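The only non-obvious line here is the index construction: tuple([slice(None)] * dim + [suffix_idx]) selects suffix_idx along axis dim while leaving every leading axis intact. A small standalone demo with made-up (B, C, T, H, W) shapes and dim=2:

    import torch

    guide_frames = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3, 1, 1)
    dim, suffix_idx = 2, [0, 2]                  # keep guide frames 0 and 2
    idx = tuple([slice(None)] * dim + [suffix_idx])
    sliced_guide = guide_frames[idx]             # shape (2, 4, 2, 1, 1)
    video_slice = torch.zeros(2, 4, 8, 1, 1)
    augmented = torch.cat([video_slice, sliced_guide], dim=dim)
    print(augmented.shape)                       # torch.Size([2, 4, 10, 1, 1])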
@dataclass
class WindowingContext:
tensor: torch.Tensor
suffix: torch.Tensor | None
guide_frames: torch.Tensor | None
aux_data: Any
latent_shapes: list | None
is_multimodal: bool
@dataclass
class ContextSchedule:
@@ -215,8 +241,8 @@ class IndexListContextHandler(ContextHandlerABC):
self.callbacks = {}
def _get_latent_shapes(self, conds):
"""Extract latent_shapes from conditioning. Returns None if absent."""
@staticmethod
def _get_latent_shapes(conds):
for cond_list in conds:
if cond_list is None:
continue
@@ -226,20 +252,20 @@ class IndexListContextHandler(ContextHandlerABC):
return model_conds['latent_shapes'].cond
return None
def _decompose(self, x, latent_shapes):
"""Packed tensor -> list of per-modality tensors."""
@staticmethod
def _unpack(combined_latent, latent_shapes):
if latent_shapes is not None and len(latent_shapes) > 1:
return comfy.utils.unpack_latents(x, latent_shapes)
return [x]
return comfy.utils.unpack_latents(combined_latent, latent_shapes)
return [combined_latent]
def _compose(self, modalities):
"""List of per-modality tensors -> single tensor for pipeline."""
if len(modalities) > 1:
return comfy.utils.pack_latents(modalities)
return modalities[0], [modalities[0].shape]
@staticmethod
def _pack(latents):
if len(latents) > 1:
return comfy.utils.pack_latents(latents)
return latents[0], [latents[0].shape]
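What the _pack/_unpack pair round-trips: a toy stand-in for comfy.utils.pack_latents / unpack_latents, assuming the packed layout is [B, 1, flat] with each modality flattened and concatenated along the last axis (the FreeNoise comment later in this diff describes that layout; the real helpers may differ in detail):

    import math
    import torch

    def toy_pack(latents):
        shapes = [t.shape for t in latents]
        return torch.cat([t.reshape(t.shape[0], 1, -1) for t in latents], dim=-1), shapes

    def toy_unpack(packed, shapes):
        out, offset = [], 0
        for s in shapes:
            n = math.prod(s[1:])
            out.append(packed[:, :, offset:offset + n].reshape(s))
            offset += n
        return out

    video = torch.randn(1, 4, 8, 2, 2)   # hypothetical video latent
    audio = torch.randn(1, 4, 24)        # hypothetical audio latent
    packed, shapes = toy_pack([video, audio])
    assert all(torch.equal(a, b) for a, b in zip(toy_unpack(packed, shapes), [video, audio]))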
def _patch_latent_shapes(self, sub_conds, new_shapes):
"""Patch latent_shapes CONDConstant in (already-copied) sub_conds."""
@staticmethod
def _patch_latent_shapes(sub_conds, new_shapes):
for cond_list in sub_conds:
if cond_list is None:
continue
@@ -248,14 +274,48 @@ class IndexListContextHandler(ContextHandlerABC):
if 'latent_shapes' in model_conds:
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
def _build_window_data(self, x_in: torch.Tensor, conds: list[list[dict]]) -> WindowingContext:
latent_shapes = self._get_latent_shapes(conds)
primary = self._decompose(x_in, latent_shapes)[0]
guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0
video_frames = primary.size(self.dim) - guide_count
is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
if is_multimodal:
video_latent = comfy.utils.unpack_latents(x_in, latent_shapes)[0]
else:
video_latent = x_in
guide_entries = None
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
guide_entries = entries.cond
break
if guide_entries is not None:
break
guide_frame_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries is not None else 0
primary_frame_count = video_latent.size(self.dim) - guide_frame_count
primary_frames = video_latent.narrow(self.dim, 0, primary_frame_count)
guide_frames = video_latent.narrow(self.dim, primary_frame_count, guide_frame_count) if guide_frame_count > 0 else None
if guide_frame_count > 0:
aux_data = {"guide_entries": guide_entries, "guide_frames": guide_frames}
else:
aux_data = None
return WindowingContext(
tensor=primary_frames, guide_frames=guide_frames, aux_data=aux_data,
latent_shapes=latent_shapes, is_multimodal=is_multimodal)
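The split convention _build_window_data relies on (guide frames appended after the video frames on the frame axis), shown with torch.narrow on a made-up (B, C, T, H, W) latent and dim=2:

    import torch

    latent = torch.randn(1, 4, 11, 2, 2)   # 8 video frames + 3 guide frames, made up
    dim, guide_frame_count = 2, 3
    primary_frame_count = latent.size(dim) - guide_frame_count
    primary_frames = latent.narrow(dim, 0, primary_frame_count)                 # (1, 4, 8, 2, 2)
    guide_frames = latent.narrow(dim, primary_frame_count, guide_frame_count)   # (1, 4, 3, 2, 2)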
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
self._window_data = self._build_window_data(x_in, conds)
video_frames = self._window_data.tensor.size(self.dim)
guide_frames = self._window_data.guide_frames.size(self.dim) if self._window_data.guide_frames is not None else 0
if video_frames > self.context_length:
if guide_count > 0:
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_count} guide frames excluded).")
if guide_frames > 0:
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} video frames ({guide_frames} guide frames).")
else:
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {video_frames} frames.")
if self.cond_retain_index_list:
@@ -367,15 +427,9 @@ class IndexListContextHandler(ContextHandlerABC):
self._model = model
self.set_step(timestep, model_options)
# Check if multimodal or model has auxiliary frames requiring the extended path
latent_shapes = self._get_latent_shapes(conds)
is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
if is_multimodal:
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, latent_shapes)
window_data = model.prepare_for_windowing(x_in, conds, self.dim)
if window_data.suffix is not None or window_data.aux_data is not None:
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options,
latent_shapes, window_data)
window_data = self._window_data
if window_data.is_multimodal or (window_data.guide_frames is not None and window_data.guide_frames.size(self.dim) > 0):
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, window_data)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
@@ -410,101 +464,104 @@ class IndexListContextHandler(ContextHandlerABC):
def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor,
timestep: torch.Tensor, model_options: dict[str],
latent_shapes, window_data: WindowingContext=None):
"""Extended execute path for multimodal models and models with auxiliary frames."""
modalities = self._decompose(x_in, latent_shapes)
is_multimodal = len(modalities) > 1
window_data: WindowingContext):
"""Extended execute path for multimodal models and models with guide frames appended to the noise latent."""
latents = self._unpack(x_in, window_data.latent_shapes)
is_multimodal = window_data.is_multimodal
if window_data is None:
window_data = model.prepare_for_windowing(modalities[0], conds, self.dim)
primary_frames = window_data.tensor
num_guide_frames = window_data.guide_frames.size(self.dim) if window_data.guide_frames is not None else 0
video_primary = window_data.tensor
aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0
context_windows = self.get_context_windows(model, video_primary, model_options)
context_windows = self.get_context_windows(model, primary_frames, model_options)
enumerated_context_windows = list(enumerate(context_windows))
total_windows = len(enumerated_context_windows)
# Accumulators sized to video portion for primary, full for other modalities
accum_modalities = list(modalities)
if window_data.suffix is not None:
accum_modalities[0] = video_primary
accum_shape_refs = list(latents)
if window_data.guide_frames is not None:
accum_shape_refs[0] = primary_frames
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_modalities]
accum = [[torch.zeros_like(m) for _ in conds] for m in accum_shape_refs]
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs]
else:
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_modalities]
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_modalities]
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in accum_shape_refs]
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in accum_shape_refs]
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
for window_idx, window in enumerated_context_windows:
comfy.model_management.throw_exception_if_processing_interrupted()
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {video_primary.shape[self.dim]}"
+ (f" (+{aux_count} aux)" if aux_count > 0 else "")
+ (f" [{len(modalities)} modalities]" if is_multimodal else ""))
logging.info(f"Context window {window_idx + 1}/{total_windows}: frames {window.index_list[0]}-{window.index_list[-1]} of {primary_frames.shape[self.dim]}"
+ (f" (+{num_guide_frames} guide frames)" if num_guide_frames > 0 else "")
+ (f" [{len(latents)} modalities]" if is_multimodal else ""))
# Per-modality window indices
if is_multimodal:
map_shapes = latent_shapes
if video_primary.size(self.dim) != modalities[0].size(self.dim):
map_shapes = list(latent_shapes)
video_shape = list(latent_shapes[0])
video_shape[self.dim] = video_primary.size(self.dim)
map_shapes = window_data.latent_shapes
if primary_frames.size(self.dim) != latents[0].size(self.dim):
map_shapes = list(window_data.latent_shapes)
video_shape = list(window_data.latent_shapes[0])
video_shape[self.dim] = primary_frames.size(self.dim)
map_shapes[0] = torch.Size(video_shape)
per_mod_indices = model.map_context_window_to_modalities(
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
try:
per_modality_indices = model.map_context_window_to_modalities(
window.index_list, map_shapes, self.dim)
except AttributeError:
raise NotImplementedError(
f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
modality_windows = {}
for mod_idx in range(1, len(modalities)):
for mod_idx in range(1, len(latents)):
modality_windows[mod_idx] = IndexListContextWindow(
per_mod_indices[mod_idx], dim=self.dim,
total_frames=modalities[mod_idx].shape[self.dim])
per_modality_indices[mod_idx], dim=self.dim,
total_frames=latents[mod_idx].shape[self.dim])
window = IndexListContextWindow(
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
window.index_list, dim=self.dim, total_frames=primary_frames.shape[self.dim],
modality_windows=modality_windows)
# Build per-modality windows list
mod_windows = [window]
per_modality_windows_list = [window]
if is_multimodal:
for mod_idx in range(1, len(modalities)):
mod_windows.append(modality_windows[mod_idx])
for mod_idx in range(1, len(latents)):
per_modality_windows_list.append(modality_windows[mod_idx])
# Slice video, then let model inject auxiliary frames
sliced_video = mod_windows[0].get_tensor(video_primary, retain_index_list=self.cond_retain_index_list)
sliced_primary, num_aux = model.prepare_window_input(
sliced_video, window, window_data.aux_data, self.dim)
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
# Slice video, then inject overlapping guide frames if present
sliced_video = per_modality_windows_list[0].get_tensor(primary_frames, retain_index_list=self.cond_retain_index_list)
if window_data.aux_data is not None:
sliced_primary, num_guide_frames = inject_guide_frames_into_window(sliced_video, window, window_data.aux_data, self.dim)
else:
sliced_primary, num_guide_frames = sliced_video, 0
sliced = [sliced_primary] + [per_modality_windows_list[mi].get_tensor(latents[mi]) for mi in range(1, len(latents))]
sub_x, sub_shapes = self._compose(sliced)
sub_x, sub_shapes = self._pack(sliced)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)
model_options["transformer_options"]["context_window"] = window
sub_timestep = window.get_tensor(timestep, dim=0)
sub_conds = [self.get_resized_cond(cond, video_primary, window) for cond in conds]
sub_conds = [self.get_resized_cond(cond, primary_frames, window) for cond in conds]
if is_multimodal:
self._patch_latent_shapes(sub_conds, sub_shapes)
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
# Decompose output per modality
out_per_mod = [self._decompose(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
# Unpack output per modality
out_per_modality = [self._unpack(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
# Strip auxiliary frames from primary output before accumulation
if num_aux > 0:
# Strip guide frames from primary output before accumulation
if num_guide_frames > 0:
window_len = len(window.index_list)
for ci in range(len(sub_conds_out)):
out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len)
out_per_modality[ci][0] = out_per_modality[ci][0].narrow(self.dim, 0, window_len)
# Accumulate per modality
for mod_idx in range(len(accum_modalities)):
mw = mod_windows[mod_idx]
mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
for mod_idx in range(len(accum_shape_refs)):
mw = per_modality_windows_list[mod_idx]
sub_conds_out_per_modality = [out_per_modality[ci][mod_idx] for ci in range(len(sub_conds_out))]
self.combine_context_window_results(
accum_modalities[mod_idx], mod_sub_out, sub_conds, mw,
accum_shape_refs[mod_idx], sub_conds_out_per_modality, sub_conds, mw,
window_idx, total_windows, timestep,
accum[mod_idx], counts[mod_idx], biases[mod_idx])
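The accumulate-then-normalize fusion in miniature, with flat weights (the real handler supports weighted fuse methods via counts and biases; this sketch only shows how overlapped frames average out):

    import torch

    total_frames = 10
    windows = [list(range(0, 6)), list(range(4, 10))]   # two overlapping windows
    accum = torch.zeros(total_frames)
    counts = torch.zeros(total_frames)
    for w in windows:
        out = torch.ones(len(w))   # stand-in for a per-window model output
        accum[w] += out
        counts[w] += 1.0
    fused = accum / counts         # frames 4-5 were summed twice, then averaged
    print(fused)                   # all ones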
@@ -512,15 +569,15 @@ class IndexListContextHandler(ContextHandlerABC):
result = []
for ci in range(len(conds)):
finalized = []
for mod_idx in range(len(accum_modalities)):
for mod_idx in range(len(accum_shape_refs)):
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
accum[mod_idx][ci] /= counts[mod_idx][ci]
f = accum[mod_idx][ci]
if mod_idx == 0 and window_data.suffix is not None:
f = torch.cat([f, window_data.suffix], dim=self.dim)
if mod_idx == 0 and window_data.guide_frames is not None:
f = torch.cat([f, window_data.guide_frames], dim=self.dim)
finalized.append(f)
composed, _ = self._compose(finalized)
result.append(composed)
packed, _ = self._pack(finalized)
result.append(packed)
return result
finally:
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
@@ -616,11 +673,8 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
# For packed multimodal tensors (e.g. LTXAV), noise is [B, 1, flat] and FreeNoise
# must only shuffle the video portion. Unpack, apply to video, repack.
latent_shapes = None
try:
latent_shapes = guider.conds['positive'][0]['model_conds']['latent_shapes'].cond
except (KeyError, IndexError, AttributeError):
pass
latent_shapes = IndexListContextHandler._get_latent_shapes(
[guider.conds.get('positive', guider.conds.get('negative', []))])
if latent_shapes is not None and len(latent_shapes) > 1:
modalities = comfy.utils.unpack_latents(noise, latent_shapes)
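A self-contained sketch of the unpack / modify-video / repack dance the comment describes, assuming the [B, 1, flat] packing above and using a random permutation as a stand-in for the elided FreeNoise shuffle:

    import torch

    video = torch.randn(1, 4, 8, 2, 2)
    audio = torch.randn(1, 4, 24)
    flat = torch.cat([video.reshape(1, 1, -1), audio.reshape(1, 1, -1)], dim=-1)
    v_numel = video.numel()
    v = flat[:, :, :v_numel].reshape(video.shape)
    v = v[:, :, torch.randperm(v.shape[2])]              # shuffle only the video frames
    flat = torch.cat([v.reshape(1, 1, -1), flat[:, :, v_numel:]], dim=-1)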

View File

@@ -287,12 +287,6 @@ class BaseModel(torch.nn.Module):
return data
return None
def prepare_for_windowing(self, primary, conds, dim):
return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
def prepare_window_input(self, video_slice, window, aux_data, dim):
return video_slice, 0
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
"""Override in subclasses to handle model-specific cond slicing for context windows.
Return a sliced cond object, or None to fall through to default handling.
@@ -1098,7 +1092,7 @@ class LTXAV(BaseModel):
for i in range(1, len(latent_shapes)):
mod_total = latent_shapes[i][dim]
# Length proportional to video window frame count (not index span)
# Length proportional to video window frame count
mod_window_len = max(round(video_window_len * mod_total / video_total), 1)
# Anchor to end of video range
v_end = max(primary_indices) + 1
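The proportional length rule from this hunk, isolated with made-up totals (an audio modality with twice the video's frame count); the max(..., 1) guard matters when the other modality has fewer frames than the video and the rounded length would hit zero:

    video_total, mod_total = 48, 96
    video_window_len = 16
    mod_window_len = max(round(video_window_len * mod_total / video_total), 1)
    assert mod_window_len == 32
    assert max(round(1 * 10 / 48), 1) == 1   # guard case: sparse modality, tiny window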
@@ -1108,17 +1102,6 @@ class LTXAV(BaseModel):
return result
def get_guide_frame_count(self, x, conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
gae = model_conds.get('guide_attention_entries')
if gae is not None and hasattr(gae, 'cond') and gae.cond:
return sum(e["latent_shape"][0] for e in gae.cond)
return 0
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
@@ -1126,43 +1109,27 @@ class LTXAV(BaseModel):
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
gae = model_conds.get('guide_attention_entries')
if gae is not None and hasattr(gae, 'cond') and gae.cond:
return gae.cond
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def prepare_for_windowing(self, primary, conds, dim):
guide_count = self.get_guide_frame_count(primary, conds)
def prepare_window_data(self, x_in, conds, dim, window_data):
primary = comfy.utils.unpack_latents(x_in, window_data.latent_shapes)[0] if window_data.is_multimodal else x_in
guide_entries = self._get_guide_entries(conds)
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
if guide_count <= 0:
return comfy.context_windows.WindowingContext(tensor=primary, suffix=None, aux_data=None)
return comfy.context_windows.WindowingContext(
tensor=primary, guide_frames=None, aux_data=None,
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
video_len = primary.size(dim) - guide_count
video_primary = primary.narrow(dim, 0, video_len)
guide_suffix = primary.narrow(dim, video_len, guide_count)
guide_entries = self._get_guide_entries(conds)
guide_frames = primary.narrow(dim, video_len, guide_count)
return comfy.context_windows.WindowingContext(
tensor=video_primary, suffix=guide_suffix,
aux_data={"guide_entries": guide_entries, "guide_suffix": guide_suffix})
tensor=video_primary, guide_frames=guide_frames,
aux_data={"guide_entries": guide_entries, "guide_frames": guide_frames},
latent_shapes=window_data.latent_shapes, is_multimodal=window_data.is_multimodal)
def prepare_window_input(self, video_slice, window, aux_data, dim):
if aux_data is None:
return video_slice, 0
guide_entries = aux_data["guide_entries"]
guide_suffix = aux_data["guide_suffix"]
if guide_entries is None:
window.guide_suffix_indices = []
window.guide_overlap_info = []
window.guide_kf_local_positions = []
return video_slice, 0
overlap = comfy.context_windows._compute_guide_overlap(guide_entries, window.index_list)
suffix_idx, overlap_info, kf_local_pos, num_guide = overlap
window.guide_suffix_indices = suffix_idx
window.guide_overlap_info = overlap_info
window.guide_kf_local_positions = kf_local_pos
if num_guide > 0:
idx = tuple([slice(None)] * dim + [suffix_idx])
sliced_guide = guide_suffix[idx]
return torch.cat([video_slice, sliced_guide], dim=dim), num_guide
return video_slice, 0
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
# Audio denoise mask — slice using audio modality window
@@ -1181,7 +1148,7 @@ class LTXAV(BaseModel):
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
suffix_indices = window.guide_suffix_indices
suffix_indices = window.guide_frames_indices
if suffix_indices:
idx = tuple([slice(None)] * window.dim + [suffix_indices])
sliced_guide = guide_mask[idx].to(device)
@@ -1199,14 +1166,31 @@ class LTXAV(BaseModel):
patchifier = self.diffusion_model.patchifier
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
scale_factors = self.diffusion_model.vae_scale_factors
pixel_coords = latent_to_pixel_coords(
latent_coords,
self.diffusion_model.vae_scale_factors,
scale_factors,
causal_fix=self.diffusion_model.causal_temporal_positioning)
tokens = []
for pos in kf_local_pos:
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
pixel_coords = pixel_coords[:, :, tokens, :]
# Adjust spatial end positions for dilated (downscaled) guides.
# Each guide entry may have a different downscale factor; expand the
# per-entry factor to cover all tokens belonging to that entry.
downscale_factors = getattr(window, 'guide_downscale_factors', [])
overlap_info = window.guide_overlap_info
if downscale_factors:
per_token_factor = []
for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
per_token_factor.extend([dsf] * (overlap_count * H * W))
factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor(
scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
).view(1, -1, 1, 1)
pixel_coords[:, 1:, :, 1:] += spatial_end_offset
B = cond_value.cond.shape[0]
if B > 1:
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
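The per-token offset expansion at the end of this hunk, in miniature with hypothetical numbers: one guide entry, two overlapping frames at downscale factor 2, H = W = 2 (four tokens per frame), and spatial VAE scale factors (8, 8). Each downscaled guide token's spatial end coordinate then moves by (2 - 1) * 8 pixels:

    import torch

    H = W = 2
    overlap_info = [(0, 2)]                  # (entry_idx, overlap_count)
    downscale_factors = [2.0]
    per_token_factor = []
    for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
        per_token_factor.extend([dsf] * (overlap_count * H * W))
    factor = torch.tensor(per_token_factor)  # 8 tokens, all 2.0
    spatial = torch.tensor([8.0, 8.0])       # stand-in for vae scale_factors[1:]
    offset = (factor.view(1, 1, -1, 1) - 1) * spatial.view(1, -1, 1, 1)
    print(offset.shape, offset.unique())     # torch.Size([1, 2, 8, 1]) tensor([8.])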