mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
LTX2 context windows - Cleanup: Simplify IndexListContextHandler standard execute path
This commit is contained in:
parent
874690c01c
commit
3a061f4bbf
@ -367,18 +367,60 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self._model = model
|
self._model = model
|
||||||
self.set_step(timestep, model_options)
|
self.set_step(timestep, model_options)
|
||||||
|
|
||||||
# Decompose — single-modality: [x_in], multimodal: [video, audio, ...]
|
# Check if multimodal or model has auxiliary frames requiring the extended path
|
||||||
latent_shapes = self._get_latent_shapes(conds)
|
latent_shapes = self._get_latent_shapes(conds)
|
||||||
|
is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
|
||||||
|
if is_multimodal:
|
||||||
|
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options, latent_shapes)
|
||||||
|
window_data = model.prepare_for_windowing(x_in, conds, self.dim)
|
||||||
|
if window_data.suffix is not None or window_data.aux_data is not None:
|
||||||
|
return self._execute_extended(calc_cond_batch, model, conds, x_in, timestep, model_options,
|
||||||
|
latent_shapes, window_data)
|
||||||
|
|
||||||
|
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||||
|
enumerated_context_windows = list(enumerate(context_windows))
|
||||||
|
|
||||||
|
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||||
|
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||||
|
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||||
|
else:
|
||||||
|
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||||
|
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||||
|
|
||||||
|
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||||
|
callback(self, model, x_in, conds, timestep, model_options)
|
||||||
|
|
||||||
|
for enum_window in enumerated_context_windows:
|
||||||
|
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||||
|
for result in results:
|
||||||
|
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||||
|
conds_final, counts_final, biases_final)
|
||||||
|
try:
|
||||||
|
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||||
|
del counts_final
|
||||||
|
return conds_final
|
||||||
|
else:
|
||||||
|
for i in range(len(conds_final)):
|
||||||
|
conds_final[i] /= counts_final[i]
|
||||||
|
del counts_final
|
||||||
|
return conds_final
|
||||||
|
finally:
|
||||||
|
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||||
|
callback(self, model, x_in, conds, timestep, model_options)
|
||||||
|
|
||||||
|
def _execute_extended(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor,
|
||||||
|
timestep: torch.Tensor, model_options: dict[str],
|
||||||
|
latent_shapes, window_data: WindowingContext=None):
|
||||||
|
"""Extended execute path for multimodal models and models with auxiliary frames."""
|
||||||
modalities = self._decompose(x_in, latent_shapes)
|
modalities = self._decompose(x_in, latent_shapes)
|
||||||
is_multimodal = len(modalities) > 1
|
is_multimodal = len(modalities) > 1
|
||||||
primary = modalities[0]
|
|
||||||
|
|
||||||
# Let model strip auxiliary frames (e.g. guide frames)
|
if window_data is None:
|
||||||
window_data = model.prepare_for_windowing(primary, conds, self.dim)
|
window_data = model.prepare_for_windowing(modalities[0], conds, self.dim)
|
||||||
|
|
||||||
video_primary = window_data.tensor
|
video_primary = window_data.tensor
|
||||||
aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0
|
aux_count = window_data.suffix.size(self.dim) if window_data.suffix is not None else 0
|
||||||
|
|
||||||
# Windows from video portion only
|
|
||||||
context_windows = self.get_context_windows(model, video_primary, model_options)
|
context_windows = self.get_context_windows(model, video_primary, model_options)
|
||||||
enumerated_context_windows = list(enumerate(context_windows))
|
enumerated_context_windows = list(enumerate(context_windows))
|
||||||
total_windows = len(enumerated_context_windows)
|
total_windows = len(enumerated_context_windows)
|
||||||
@ -407,14 +449,13 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
# Per-modality window indices
|
# Per-modality window indices
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
map_shapes = latent_shapes
|
map_shapes = latent_shapes
|
||||||
if video_primary.size(self.dim) != primary.size(self.dim):
|
if video_primary.size(self.dim) != modalities[0].size(self.dim):
|
||||||
map_shapes = list(latent_shapes)
|
map_shapes = list(latent_shapes)
|
||||||
video_shape = list(latent_shapes[0])
|
video_shape = list(latent_shapes[0])
|
||||||
video_shape[self.dim] = video_primary.size(self.dim)
|
video_shape[self.dim] = video_primary.size(self.dim)
|
||||||
map_shapes[0] = torch.Size(video_shape)
|
map_shapes[0] = torch.Size(video_shape)
|
||||||
per_mod_indices = model.map_context_window_to_modalities(
|
per_mod_indices = model.map_context_window_to_modalities(
|
||||||
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
|
window.index_list, map_shapes, self.dim) if hasattr(model, 'map_context_window_to_modalities') else [window.index_list]
|
||||||
# Build per-modality windows and attach to primary window
|
|
||||||
modality_windows = {}
|
modality_windows = {}
|
||||||
for mod_idx in range(1, len(modalities)):
|
for mod_idx in range(1, len(modalities)):
|
||||||
modality_windows[mod_idx] = IndexListContextWindow(
|
modality_windows[mod_idx] = IndexListContextWindow(
|
||||||
@ -423,11 +464,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
window = IndexListContextWindow(
|
window = IndexListContextWindow(
|
||||||
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
|
window.index_list, dim=self.dim, total_frames=video_primary.shape[self.dim],
|
||||||
modality_windows=modality_windows)
|
modality_windows=modality_windows)
|
||||||
else:
|
|
||||||
per_mod_indices = [window.index_list]
|
|
||||||
|
|
||||||
# Build per-modality windows list (including primary)
|
# Build per-modality windows list
|
||||||
mod_windows = [window] # primary window at index 0
|
mod_windows = [window]
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
for mod_idx in range(1, len(modalities)):
|
for mod_idx in range(1, len(modalities)):
|
||||||
mod_windows.append(modality_windows[mod_idx])
|
mod_windows.append(modality_windows[mod_idx])
|
||||||
@ -438,10 +477,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
sliced_video, window, window_data.aux_data, self.dim)
|
sliced_video, window, window_data.aux_data, self.dim)
|
||||||
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
|
sliced = [sliced_primary] + [mod_windows[mi].get_tensor(modalities[mi]) for mi in range(1, len(modalities))]
|
||||||
|
|
||||||
# Compose for pipeline
|
|
||||||
sub_x, sub_shapes = self._compose(sliced)
|
sub_x, sub_shapes = self._compose(sliced)
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)
|
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, None, None)
|
||||||
|
|
||||||
@ -462,7 +499,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
for ci in range(len(sub_conds_out)):
|
for ci in range(len(sub_conds_out)):
|
||||||
out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len)
|
out_per_mod[ci][0] = out_per_mod[ci][0].narrow(self.dim, 0, window_len)
|
||||||
|
|
||||||
# Accumulate per modality (using video-only sizes)
|
# Accumulate per modality
|
||||||
for mod_idx in range(len(accum_modalities)):
|
for mod_idx in range(len(accum_modalities)):
|
||||||
mw = mod_windows[mod_idx]
|
mw = mod_windows[mod_idx]
|
||||||
mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
|
mod_sub_out = [out_per_mod[ci][mod_idx] for ci in range(len(sub_conds_out))]
|
||||||
@ -479,7 +516,6 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||||
f = accum[mod_idx][ci]
|
f = accum[mod_idx][ci]
|
||||||
# Re-append model's suffix (auxiliary frames stripped before windowing)
|
|
||||||
if mod_idx == 0 and window_data.suffix is not None:
|
if mod_idx == 0 and window_data.suffix is not None:
|
||||||
f = torch.cat([f, window_data.suffix], dim=self.dim)
|
f = torch.cat([f, window_data.suffix], dim=self.dim)
|
||||||
finalized.append(f)
|
finalized.append(f)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user