mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
LTX2 context windows - Cleanup: Remove model-specific code from BaseModel. Guide + context-window support for the older LTXV model will need to be re-implemented, but that is outside the scope of the LTX2 changes.
This commit is contained in:
parent
3660533f83
commit
350237618d
@ -257,7 +257,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
primary = self._decompose(x_in, latent_shapes)[0]
|
||||
guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0
|
||||
guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0
|
||||
video_frames = primary.size(self.dim) - guide_count
|
||||
if video_frames > self.context_length:
|
||||
if guide_count > 0:
|
||||
@ -380,7 +380,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
primary = modalities[0]
|
||||
|
||||
# Separate guide frames from primary modality (guides are appended at the end)
|
||||
guide_count = model.get_guide_frame_count(primary, conds) if model is not None else 0
|
||||
guide_count = model.get_guide_frame_count(primary, conds) if hasattr(model, 'get_guide_frame_count') else 0
|
||||
if guide_count > 0:
|
||||
video_len = primary.size(self.dim) - guide_count
|
||||
video_primary = primary.narrow(self.dim, 0, video_len)
|
||||
@ -427,7 +427,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
video_shape[self.dim] = video_shape[self.dim] - guide_count
|
||||
map_shapes[0] = torch.Size(video_shape)
|
||||
per_mod_indices = model.map_context_window_to_modalities(
|
||||
window.index_list, map_shapes, self.dim)
|
||||
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
|
||||
# Build per-modality windows and attach to primary window
|
||||
modality_windows = {}
|
||||
for mod_idx in range(1, len(modalities)):
|
||||
|
||||
@ -293,71 +293,6 @@ class BaseModel(torch.nn.Module):
|
||||
Use comfy.context_windows.slice_cond() for common cases."""
|
||||
return None
|
||||
|
||||
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Map the primary modality's window indices to all modalities.

    This default implementation returns the primary index list unchanged as
    the only entry; ``latent_shapes`` and ``dim`` are accepted for interface
    compatibility but are not consulted here.

    Returns:
        A list of index lists, one per modality.
    """
    return [primary_indices]
|
||||
|
||||
def get_guide_frame_count(self, x, conds):
    """Return the number of trailing guide frames appended to x along the temporal dim.

    Override in subclasses that concatenate guide reference frames to the
    latent. This default implementation ignores both arguments and reports
    no guide frames.
    """
    return 0
|
||||
|
||||
def _resize_guide_cond(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
"""Resize guide-related conditioning for context windows.
|
||||
Requires guide_suffix_indices, guide_overlap_info, and guide_kf_local_positions
|
||||
to be set on the window by _compute_guide_overlap."""
|
||||
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
cond_tensor = cond_value.cond
|
||||
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
||||
if guide_count > 0:
|
||||
T_video = x_in.size(window.dim)
|
||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||
suffix_indices = window.guide_suffix_indices
|
||||
if suffix_indices:
|
||||
idx = tuple([slice(None)] * window.dim + [suffix_indices])
|
||||
sliced_guide = guide_mask[idx].to(device)
|
||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||
else:
|
||||
return cond_value._copy_with(sliced_video)
|
||||
|
||||
if cond_key == "keyframe_idxs":
|
||||
kf_local_pos = window.guide_kf_local_positions
|
||||
if not kf_local_pos:
|
||||
return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
window_len = len(window.index_list)
|
||||
patchifier = self.diffusion_model.patchifier
|
||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
|
||||
pixel_coords = latent_to_pixel_coords(
|
||||
latent_coords,
|
||||
self.diffusion_model.vae_scale_factors,
|
||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||
tokens = []
|
||||
for pos in kf_local_pos:
|
||||
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
|
||||
pixel_coords = pixel_coords[:, :, tokens, :]
|
||||
B = cond_value.cond.shape[0]
|
||||
if B > 1:
|
||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||
return cond_value._copy_with(pixel_coords)
|
||||
|
||||
if cond_key == "guide_attention_entries":
|
||||
overlap_info = window.guide_overlap_info
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
new_entries = []
|
||||
for entry_idx, overlap_count in overlap_info:
|
||||
e = cond_value.cond[entry_idx]
|
||||
new_entries.append({**e,
|
||||
"pre_filter_count": overlap_count * H * W,
|
||||
"latent_shape": [overlap_count, H, W]})
|
||||
return cond_value._copy_with(new_entries)
|
||||
|
||||
return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@ -1081,20 +1016,6 @@ class LTXV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
    """Return ``latent_image`` unchanged; no inpaint scaling is applied here.

    ``sigma``, ``noise`` and any extra keyword arguments are accepted but unused.
    """
    return latent_image
|
||||
|
||||
def get_guide_frame_count(self, x, conds):
    """Return the guide frame count taken from the first non-empty guide entry list.

    Walks ``conds`` (a list of cond lists; ``None`` lists are skipped) and
    looks at each cond dict's ``model_conds['guide_attention_entries']``.
    The first one whose ``cond`` is non-empty determines the result: the sum
    of ``latent_shape[0]`` over its entries. ``x`` is not consulted.

    Returns:
        The summed count, or 0 when no cond carries guide entries.
    """
    for cond_list in conds:
        if cond_list is None:
            continue
        for cond_dict in cond_list:
            gae = cond_dict.get('model_conds', {}).get('guide_attention_entries')
            if gae is not None and hasattr(gae, 'cond') and gae.cond:
                return sum(e["latent_shape"][0] for e in gae.cond)
    return 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize a conditioning entry for a context window.

    Delegates entirely to ``self._resize_guide_cond`` with the same
    arguments and returns whatever it returns.
    """
    return self._resize_guide_cond(cond_key, cond_value, window, x_in, device, retain_index_list)
|
||||
|
||||
class LTXAV(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
|
||||
@ -1193,17 +1114,64 @@ class LTXAV(BaseModel):
|
||||
return 0
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize audio and guide conditioning for a context window.

    Handled keys:
        * ``"audio_denoise_mask"`` - sliced with the audio modality window
          (``window.modality_windows[1]``) along dim 2.
        * ``"denoise_mask"`` - only when the mask is longer than ``x_in``
          along ``window.dim``; it is split into video + guide portions and
          each is sliced separately.
        * ``"keyframe_idxs"`` - pixel coordinates are regenerated for the
          window and restricted to the guide positions.
        * ``"guide_attention_entries"`` - per-guide counts are rewritten
          from the window overlap.

    The guide branches read ``window.guide_suffix_indices``,
    ``window.guide_kf_local_positions`` and ``window.guide_overlap_info``.
    The guide logic is implemented inline, so this method does not depend on
    a shared ``_resize_guide_cond`` helper.

    Args:
        cond_key: Name of the conditioning entry being resized.
        cond_value: Conditioning wrapper exposing ``cond`` and ``_copy_with``.
        window: Context window for the primary modality.
        x_in: Latent being windowed (guide frames excluded).
            NOTE(review): indexed as 5-D with height/width at dims 3/4 - confirm.
        device: Target device for sliced tensors.
        retain_index_list: Forwarded untouched to ``window.get_tensor``.

    Returns:
        ``cond_value._copy_with(...)`` for a handled key, otherwise ``None``.
    """
    # Audio denoise mask - slice using audio modality window
    if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
        audio_window = window.modality_windows.get(1)
        if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
            sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
            return cond_value._copy_with(sliced)

    # Video denoise mask - split into video + guide portions, slice each
    if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
        mask = cond_value.cond
        video_len = x_in.size(window.dim)
        guide_count = mask.size(window.dim) - video_len
        if guide_count > 0:
            video_mask = mask.narrow(window.dim, 0, video_len)
            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
            suffix_indices = window.guide_suffix_indices
            if not suffix_indices:
                return cond_value._copy_with(sliced_video)
            guide_mask = mask.narrow(window.dim, video_len, guide_count)
            idx = tuple([slice(None)] * window.dim + [suffix_indices])
            sliced_guide = guide_mask[idx].to(device)
            return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
        # No surplus frames: fall through and end up returning None.

    # Keyframe indices - regenerate pixel coords for window, select guide positions
    if cond_key == "keyframe_idxs":
        kf_local_pos = window.guide_kf_local_positions
        if not kf_local_pos:
            return cond_value._copy_with(cond_value.cond[:, :, :0, :])  # empty
        H, W = x_in.shape[3], x_in.shape[4]
        window_len = len(window.index_list)
        patchifier = self.diffusion_model.patchifier
        latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
        from comfy.ldm.lightricks.symmetric_patchifier import latent_to_pixel_coords
        pixel_coords = latent_to_pixel_coords(
            latent_coords,
            self.diffusion_model.vae_scale_factors,
            causal_fix=self.diffusion_model.causal_temporal_positioning)
        # Each local position contributes one contiguous run of H * W token indices.
        frame_tokens = H * W
        tokens = [t for pos in kf_local_pos
                  for t in range(pos * frame_tokens, (pos + 1) * frame_tokens)]
        pixel_coords = pixel_coords[:, :, tokens, :]
        B = cond_value.cond.shape[0]
        if B > 1:
            pixel_coords = pixel_coords.expand(B, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)

    # Guide attention entries - adjust per-guide counts based on window overlap
    if cond_key == "guide_attention_entries":
        H, W = x_in.shape[3], x_in.shape[4]
        new_entries = [
            {**cond_value.cond[entry_idx],
             "pre_filter_count": overlap_count * H * W,
             "latent_shape": [overlap_count, H, W]}
            for entry_idx, overlap_count in window.guide_overlap_info
        ]
        return cond_value._copy_with(new_entries)

    return None
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user