mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Merge 1356f0158c into 7bbf1e8169
This commit is contained in:
commit
71f7241baa
@ -22,26 +22,25 @@ class CompressedTimestep:
|
|||||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||||
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
|
||||||
|
|
||||||
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
|
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
|
||||||
"""
|
"""
|
||||||
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
|
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
|
||||||
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
|
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
|
||||||
|
patches_per_frame: spatial patches per frame; pass None to disable compression.
|
||||||
"""
|
"""
|
||||||
self.batch_size, num_tokens, self.feature_dim = tensor.shape
|
self.batch_size, n, self.feature_dim = tensor.shape
|
||||||
|
if per_frame:
|
||||||
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
|
|
||||||
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
|
|
||||||
self.patches_per_frame = patches_per_frame
|
self.patches_per_frame = patches_per_frame
|
||||||
self.num_frames = num_tokens // patches_per_frame
|
self.num_frames = n
|
||||||
|
self.data = tensor
|
||||||
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
|
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
|
||||||
# All patches in a frame are identical, so we only keep the first one
|
self.patches_per_frame = patches_per_frame
|
||||||
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
|
self.num_frames = n // patches_per_frame
|
||||||
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
|
# All patches in a frame are identical — keep only the first.
|
||||||
|
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
|
||||||
else:
|
else:
|
||||||
# Not divisible or too small - store directly without compression
|
|
||||||
self.patches_per_frame = 1
|
self.patches_per_frame = 1
|
||||||
self.num_frames = num_tokens
|
self.num_frames = n
|
||||||
self.data = tensor
|
self.data = tensor
|
||||||
|
|
||||||
def expand(self):
|
def expand(self):
|
||||||
@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):
|
|||||||
|
|
||||||
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
|
||||||
"""Prepare timestep embeddings."""
|
"""Prepare timestep embeddings."""
|
||||||
# TODO: some code reuse is needed here.
|
|
||||||
grid_mask = kwargs.get("grid_mask", None)
|
grid_mask = kwargs.get("grid_mask", None)
|
||||||
if grid_mask is not None:
|
|
||||||
timestep = timestep[:, grid_mask]
|
|
||||||
|
|
||||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
|
||||||
|
|
||||||
v_timestep, v_embedded_timestep = self.adaln_single(
|
|
||||||
timestep_scaled.flatten(),
|
|
||||||
{"resolution": None, "aspect_ratio": None},
|
|
||||||
batch_size=batch_size,
|
|
||||||
hidden_dtype=hidden_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
|
|
||||||
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
|
|
||||||
orig_shape = kwargs.get("orig_shape")
|
orig_shape = kwargs.get("orig_shape")
|
||||||
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
has_spatial_mask = kwargs.get("has_spatial_mask", None)
|
||||||
v_patches_per_frame = None
|
v_patches_per_frame = None
|
||||||
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
|
||||||
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
|
|
||||||
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
v_patches_per_frame = orig_shape[3] * orig_shape[4]
|
||||||
|
|
||||||
# Reshape to [batch_size, num_tokens, dim] and compress for storage
|
# Used by compute_prompt_timestep and the audio cross-attention paths.
|
||||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
|
||||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
|
||||||
|
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
|
||||||
|
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
|
||||||
|
if per_frame_path:
|
||||||
|
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
|
||||||
|
if grid_mask is not None:
|
||||||
|
# All-or-nothing per frame when has_spatial_mask=False.
|
||||||
|
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
|
||||||
|
ts_input = per_frame * self.timestep_scale_multiplier
|
||||||
|
else:
|
||||||
|
ts_input = timestep_scaled
|
||||||
|
|
||||||
|
v_timestep, v_embedded_timestep = self.adaln_single(
|
||||||
|
ts_input.flatten(),
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||||
|
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
|
||||||
|
|
||||||
v_prompt_timestep = compute_prompt_timestep(
|
v_prompt_timestep = compute_prompt_timestep(
|
||||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||||
|
|||||||
@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
|
|||||||
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
||||||
|
|
||||||
|
|
||||||
|
class GuideAttentionMask:
|
||||||
|
"""Holds the two per-group masks for LTXV guide self-attention.
|
||||||
|
_attention_with_guide_mask splits queries into noisy and tracked-guide
|
||||||
|
groups, so the largest mask is (1, 1, tracked_count, T).
|
||||||
|
"""
|
||||||
|
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
|
||||||
|
|
||||||
|
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
|
||||||
|
device = tracked_weights.device
|
||||||
|
dtype = tracked_weights.dtype
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
|
||||||
|
pos = tracked_weights > 0
|
||||||
|
log_w = torch.full_like(tracked_weights, finfo.min)
|
||||||
|
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
|
||||||
|
|
||||||
|
self.guide_start = guide_start
|
||||||
|
self.tracked_count = tracked_count
|
||||||
|
|
||||||
|
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
|
||||||
|
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
|
||||||
|
|
||||||
|
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
|
||||||
|
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
|
||||||
|
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
|
||||||
|
groups, so each group needs only its own sub-mask. Avoids materializing
|
||||||
|
the (1,1,T,T) dense mask.
|
||||||
|
"""
|
||||||
|
guide_start = guide_mask.guide_start
|
||||||
|
tracked_end = guide_start + guide_mask.tracked_count
|
||||||
|
|
||||||
|
out = torch.empty_like(q)
|
||||||
|
|
||||||
|
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
|
||||||
|
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||||
|
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
|
||||||
|
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||||
|
low_precision_attention=False, # sageattn mask support is unreliable
|
||||||
|
)
|
||||||
|
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||||
|
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
|
||||||
|
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||||
|
low_precision_attention=False,
|
||||||
|
)
|
||||||
|
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
|
||||||
|
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
|
||||||
|
q[:, tracked_end:, :], k, v, heads,
|
||||||
|
attn_precision=attn_precision, transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -412,8 +467,10 @@ class CrossAttention(nn.Module):
|
|||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
|
elif isinstance(mask, GuideAttentionMask):
|
||||||
|
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
|
|
||||||
# Apply per-head gating if enabled
|
# Apply per-head gating if enabled
|
||||||
if self.to_gate_logits is not None:
|
if self.to_gate_logits is not None:
|
||||||
@ -1063,7 +1120,9 @@ class LTXVModel(LTXBaseModel):
|
|||||||
additional_args["resolved_guide_entries"] = resolved_entries
|
additional_args["resolved_guide_entries"] = resolved_entries
|
||||||
|
|
||||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
|
||||||
|
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
|
||||||
|
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||||
|
|
||||||
# Total surviving guide tokens (all guides)
|
# Total surviving guide tokens (all guides)
|
||||||
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||||
@ -1099,12 +1158,12 @@ class LTXVModel(LTXBaseModel):
|
|||||||
if not resolved_entries:
|
if not resolved_entries:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if any attenuation is actually needed
|
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
|
||||||
needs_attenuation = any(
|
needs_mask = any(
|
||||||
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
e["strength"] != 1.0 or e.get("pixel_mask") is not None
|
||||||
for e in resolved_entries
|
for e in resolved_entries
|
||||||
)
|
)
|
||||||
if not needs_attenuation:
|
if not needs_mask:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build per-guide-token weights for all tracked guide tokens.
|
# Build per-guide-token weights for all tracked guide tokens.
|
||||||
@ -1159,16 +1218,11 @@ class LTXVModel(LTXBaseModel):
|
|||||||
# Concatenate per-token weights for all tracked guides
|
# Concatenate per-token weights for all tracked guides
|
||||||
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||||
|
|
||||||
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
# Skip when every weight is exactly 1.0 (additive bias would be 0).
|
||||||
if (tracked_weights >= 1.0).all():
|
if (tracked_weights == 1.0).all():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build the mask: guide tokens are at the end of the sequence.
|
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
|
||||||
# Tracked guides come first (in order), untracked follow.
|
|
||||||
return self._build_self_attention_mask(
|
|
||||||
total_tokens, num_guide_tokens, total_tracked,
|
|
||||||
tracked_weights, guide_start, device, dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||||
@ -1234,45 +1288,6 @@ class LTXVModel(LTXBaseModel):
|
|||||||
|
|
||||||
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
|
||||||
tracked_weights, guide_start, device, dtype):
|
|
||||||
"""Build a log-space additive self-attention bias mask.
|
|
||||||
|
|
||||||
Attenuates attention between noisy tokens and tracked guide tokens.
|
|
||||||
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
total_tokens: Total sequence length.
|
|
||||||
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
|
||||||
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
|
||||||
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
|
||||||
guide_start: Index where guide tokens begin in the sequence.
|
|
||||||
device: Target device.
|
|
||||||
dtype: Target dtype.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(1, 1, total_tokens, total_tokens) additive bias mask.
|
|
||||||
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
|
||||||
"""
|
|
||||||
finfo = torch.finfo(dtype)
|
|
||||||
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
|
||||||
tracked_end = guide_start + tracked_count
|
|
||||||
|
|
||||||
# Convert weights to log-space bias
|
|
||||||
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
|
||||||
log_w = torch.full_like(w, finfo.min)
|
|
||||||
positive_mask = w > 0
|
|
||||||
if positive_mask.any():
|
|
||||||
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
|
||||||
|
|
||||||
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
|
||||||
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
|
||||||
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
|
||||||
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
|
||||||
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||||
"""Process transformer blocks for LTXV."""
|
"""Process transformer blocks for LTXV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|||||||
@ -219,7 +219,7 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
|
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
|
||||||
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
|
||||||
),
|
),
|
||||||
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Conditioning.Output(display_name="positive"),
|
io.Conditioning.Output(display_name="positive"),
|
||||||
@ -298,7 +298,7 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
mask = torch.full(
|
mask = torch.full(
|
||||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||||
1.0 - strength,
|
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||||
dtype=noise_mask.dtype,
|
dtype=noise_mask.dtype,
|
||||||
device=noise_mask.device,
|
device=noise_mask.device,
|
||||||
)
|
)
|
||||||
@ -318,7 +318,7 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
|
|
||||||
mask = torch.full(
|
mask = torch.full(
|
||||||
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
(noise_mask.shape[0], 1, cond_length, 1, 1),
|
||||||
1.0 - strength,
|
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
|
||||||
dtype=noise_mask.dtype,
|
dtype=noise_mask.dtype,
|
||||||
device=noise_mask.device,
|
device=noise_mask.device,
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user