diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index 3fb87b4a3..bc09fb77e 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -22,26 +22,25 @@ class CompressedTimestep:
     """Store video timestep embeddings in compressed form using per-frame indexing."""
     __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')

-    def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
+    def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
         """
-        tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
-        patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
+        tensor: [batch, num_tokens, feature_dim] (per-token, default) or
+                [batch, num_frames, feature_dim] (per_frame=True, already compressed).
+        patches_per_frame: spatial patches per frame; pass None to disable compression.
         """
-        self.batch_size, num_tokens, self.feature_dim = tensor.shape
-
-        # Check if compression is valid (num_tokens must be divisible by patches_per_frame)
-        if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
+        self.batch_size, n, self.feature_dim = tensor.shape
+        if per_frame:
             self.patches_per_frame = patches_per_frame
-            self.num_frames = num_tokens // patches_per_frame
-
-            # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
-            # All patches in a frame are identical, so we only keep the first one
-            reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
-            self.data = reshaped[:, :, 0, :].contiguous()  # [batch, frames, feature_dim]
+            self.num_frames = n
+            self.data = tensor
+        elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
+            self.patches_per_frame = patches_per_frame
+            self.num_frames = n // patches_per_frame
+            # All patches in a frame are identical — keep only the first.
+            self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
         else:
-            # Not divisible or too small - store directly without compression
             self.patches_per_frame = 1
-            self.num_frames = num_tokens
+            self.num_frames = n
             self.data = tensor

     def expand(self):
@@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):

     def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
         """Prepare timestep embeddings."""
-        # TODO: some code reuse is needed here.
         grid_mask = kwargs.get("grid_mask", None)
-        if grid_mask is not None:
-            timestep = timestep[:, grid_mask]
-
-        timestep_scaled = timestep * self.timestep_scale_multiplier
-
-        v_timestep, v_embedded_timestep = self.adaln_single(
-            timestep_scaled.flatten(),
-            {"resolution": None, "aspect_ratio": None},
-            batch_size=batch_size,
-            hidden_dtype=hidden_dtype,
-        )
-
-        # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
-        # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
         orig_shape = kwargs.get("orig_shape")
         has_spatial_mask = kwargs.get("has_spatial_mask", None)
         v_patches_per_frame = None
         if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
-            # orig_shape[3] = height, orig_shape[4] = width (in latent space)
             v_patches_per_frame = orig_shape[3] * orig_shape[4]

-        # Reshape to [batch_size, num_tokens, dim] and compress for storage
-        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
-        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
+        # Used by compute_prompt_timestep and the audio cross-attention paths.
+        timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
+
+        # When all patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token.
+        per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
+        if per_frame_path:
+            per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
+            if grid_mask is not None:
+                # All-or-nothing per frame when has_spatial_mask=False.
+                per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
+            ts_input = per_frame * self.timestep_scale_multiplier
+        else:
+            ts_input = timestep_scaled
+
+        v_timestep, v_embedded_timestep = self.adaln_single(
+            ts_input.flatten(),
+            {"resolution": None, "aspect_ratio": None},
+            batch_size=batch_size,
+            hidden_dtype=hidden_dtype,
+        )
+        v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
+        v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)

         v_prompt_timestep = compute_prompt_timestep(
             self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index bfbc08357..80a3f08d7 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -358,6 +358,70 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
     return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output


+class GuideAttentionMask:
+    """Holds the two per-group masks for LTXV guide self-attention.
+    _attention_with_guide_mask splits queries into noisy and tracked-guide
+    groups, so the largest mask is (1, 1, tracked_count, T).
+ """ + __slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask") + + def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights): + device = tracked_weights.device + dtype = tracked_weights.dtype + finfo = torch.finfo(dtype) + + pos = tracked_weights > 0 + log_w = torch.full_like(tracked_weights, finfo.min) + log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny)) + + self.guide_start = guide_start + self.tracked_count = tracked_count + + self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype) + self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1) + + self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype) + self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1) + + def to(self, *args, **kwargs): + new = GuideAttentionMask.__new__(GuideAttentionMask) + new.guide_start = self.guide_start + new.tracked_count = self.tracked_count + new.noisy_mask = self.noisy_mask.to(*args, **kwargs) + new.tracked_mask = self.tracked_mask.to(*args, **kwargs) + return new + + +def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options): + """Apply the guide mask by partitioning Q into noisy and tracked-guide + groups, so each group needs only its own sub-mask. Avoids materializing + the (1,1,T,T) dense mask. + """ + guide_start = guide_mask.guide_start + tracked_end = guide_start + guide_mask.tracked_count + + out = torch.empty_like(q) + + out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention( + q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask, + attn_precision=attn_precision, transformer_options=transformer_options, + low_precision_attention=False, # sageattn mask support is unreliable + ) + out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention( + q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask, + attn_precision=attn_precision, transformer_options=transformer_options, + low_precision_attention=False, + ) + return out + + class CrossAttention(nn.Module): def __init__( self, @@ -412,8 +469,10 @@ class CrossAttention(nn.Module): if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) + elif isinstance(mask, GuideAttentionMask): + out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) else: - out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options) # Apply per-head gating if enabled if self.to_gate_logits is not None: @@ -1063,7 +1122,9 @@ class LTXVModel(LTXBaseModel): additional_args["resolved_guide_entries"] = resolved_entries keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] - pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs + + if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs # Total surviving guide tokens (all guides) additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] @@ -1099,12 +1160,12 @@ class LTXVModel(LTXBaseModel): if not 
             return None

-        # Check if any attenuation is actually needed
-        needs_attenuation = any(
-            e["strength"] < 1.0 or e.get("pixel_mask") is not None
+        # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
+        needs_mask = any(
+            e["strength"] != 1.0 or e.get("pixel_mask") is not None
             for e in resolved_entries
         )
-        if not needs_attenuation:
+        if not needs_mask:
             return None

         # Build per-guide-token weights for all tracked guide tokens.
@@ -1159,16 +1220,11 @@ class LTXVModel(LTXBaseModel):
         # Concatenate per-token weights for all tracked guides
         tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)

-        # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
-        if (tracked_weights >= 1.0).all():
+        # Skip when every weight is exactly 1.0 (additive bias would be 0).
+        if (tracked_weights == 1.0).all():
             return None

-        # Build the mask: guide tokens are at the end of the sequence.
-        # Tracked guides come first (in order), untracked follow.
-        return self._build_self_attention_mask(
-            total_tokens, num_guide_tokens, total_tracked,
-            tracked_weights, guide_start, device, dtype,
-        )
+        return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)

     @staticmethod
     def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@@ -1234,45 +1290,6 @@ class LTXVModel(LTXBaseModel):
         return rearrange(latent_mask, "b 1 f h w -> b (f h w)")

-    @staticmethod
-    def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
-                                   tracked_weights, guide_start, device, dtype):
-        """Build a log-space additive self-attention bias mask.
-
-        Attenuates attention between noisy tokens and tracked guide tokens.
-        Untracked guide tokens (at the end of the guide portion) keep full attention.
-
-        Args:
-            total_tokens: Total sequence length.
-            num_guide_tokens: Total guide tokens (all guides) at end of sequence.
-            tracked_count: Number of tracked guide tokens (first in the guide portion).
-            tracked_weights: (1, tracked_count) tensor, values in [0, 1].
-            guide_start: Index where guide tokens begin in the sequence.
-            device: Target device.
-            dtype: Target dtype.
-
-        Returns:
-            (1, 1, total_tokens, total_tokens) additive bias mask.
-            0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
- """ - finfo = torch.finfo(dtype) - mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) - tracked_end = guide_start + tracked_count - - # Convert weights to log-space bias - w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) - log_w = torch.full_like(w, finfo.min) - positive_mask = w > 0 - if positive_mask.any(): - log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) - - # noisy → tracked guides: each noisy row gets the same per-guide weight - mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) - # tracked guides → noisy: each guide row broadcasts its weight across noisy cols - mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) - - return mask - def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index f1f4d5319..9eed90216 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -223,7 +223,7 @@ class LTXVAddGuide(io.ComfyNode): "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " "down to the nearest multiple of 8. Negative values are counted from the end of the video.", ), - io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), ], outputs=[ io.Conditioning.Output(display_name="positive"), @@ -302,7 +302,7 @@ class LTXVAddGuide(io.ComfyNode): else: mask = torch.full( (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, + max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask dtype=noise_mask.dtype, device=noise_mask.device, ) @@ -322,7 +322,7 @@ class LTXVAddGuide(io.ComfyNode): mask = torch.full( (noise_mask.shape[0], 1, cond_length, 1, 1), - 1.0 - strength, + max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask dtype=noise_mask.dtype, device=noise_mask.device, )