mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
LTX2 context windows - Thread per-modality overlap into fuse weights
This commit is contained in:
parent
f72583d1f3
commit
a8b084ed58
@ -53,9 +53,10 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.context_overlap = context_overlap
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
@ -211,8 +212,10 @@ class WindowingState:
|
||||
return window
|
||||
|
||||
x = self.latents[0]
|
||||
primary_total = self.latent_shapes[0][self.dim]
|
||||
primary_overlap = window.context_overlap
|
||||
map_shapes = self.latent_shapes
|
||||
if x.size(self.dim) != self.latent_shapes[0][self.dim]:
|
||||
if x.size(self.dim) != primary_total:
|
||||
map_shapes = list(self.latent_shapes)
|
||||
video_shape = list(self.latent_shapes[0])
|
||||
video_shape[self.dim] = x.size(self.dim)
|
||||
@ -225,12 +228,16 @@ class WindowingState:
|
||||
f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
|
||||
modality_windows = {}
|
||||
for mod_idx in range(1, len(self.latents)):
|
||||
modality_total_frames = self.latents[mod_idx].shape[self.dim]
|
||||
ratio = modality_total_frames / primary_total if primary_total > 0 else 1
|
||||
modality_overlap = max(round(primary_overlap * ratio), 0)
|
||||
modality_windows[mod_idx] = IndexListContextWindow(
|
||||
per_modality_indices[mod_idx], dim=self.dim,
|
||||
total_frames=self.latents[mod_idx].shape[self.dim])
|
||||
total_frames=modality_total_frames,
|
||||
context_overlap=modality_overlap)
|
||||
return IndexListContextWindow(
|
||||
window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
|
||||
modality_windows=modality_windows)
|
||||
modality_windows=modality_windows, context_overlap=primary_overlap)
|
||||
|
||||
def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
|
||||
"""Slice latents for a context window, injecting guide frames where applicable.
|
||||
@ -501,7 +508,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
    """Build the context windows for ``x_in`` from this handler's schedule.

    The length of ``x_in`` along ``self.dim`` is passed to the configured
    context schedule function; every index list it returns is wrapped in an
    IndexListContextWindow that carries this handler's ``context_overlap``.
    """
    full_length = x_in.size(self.dim) # TODO: choose dim based on model
    scheduled_index_lists = self.context_schedule.func(full_length, self, model_options)
    windows = []
    for index_list in scheduled_index_lists:
        windows.append(
            IndexListContextWindow(
                index_list,
                dim=self.dim,
                total_frames=full_length,
                context_overlap=self.context_overlap,
            )
        )
    return windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
@ -646,7 +653,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
@ -849,8 +856,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
    """Return the fuse weights for one context window.

    Delegates to ``handler.fuse_method.func``. When ``context_overlap`` is
    None, the handler's own ``context_overlap`` is forwarded instead, so
    callers that do not track a per-window overlap keep working.
    """
    if context_overlap is None:
        context_overlap = handler.context_overlap
    fuse_func = handler.fuse_method.func
    return fuse_func(
        length,
        sigma=sigma,
        handler=handler,
        full_length=full_length,
        idxs=idxs,
        context_overlap=context_overlap,
    )
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
@ -868,18 +876,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
    """Return per-frame fuse weights with linear ramps over the overlap regions.

    Args:
        length: number of frames in the window; the result has this many entries.
        full_length: total number of frames along the windowed dimension.
        idxs: frame indices covered by this window.
        context_overlap: number of frames at each window edge to ramp.
            Values are clamped to the range [0, length].
        **kwargs: extra fuse-method arguments; accepted and ignored here.

    Returns:
        A 1-D torch tensor of ones, with the first ``context_overlap`` entries
        ramping up from 1e-37 to 1 unless the window starts at index 0, and the
        last ``context_overlap`` entries ramping down from 1 to 1e-37 unless
        the window reaches ``full_length - 1``.
    """
    # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
    # only expected overlap is given different weights
    weights_torch = torch.ones((length))
    # Clamp so the ramp always fits inside the window.
    overlap = max(0, min(context_overlap, length))
    if overlap == 0:
        # Nothing to blend. Also avoids `weights_torch[-0:]`, which selects the
        # WHOLE tensor and would fail when assigned an empty ramp.
        return weights_torch
    # blend left-side on all except first window
    if min(idxs) > 0:
        weights_torch[:overlap] = torch.linspace(1e-37, 1, overlap)
    # blend right-side on all except last window
    if max(idxs) < full_length - 1:
        weights_torch[-overlap:] = torch.linspace(1, 1e-37, overlap)
    return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user