Mirror of https://github.com/comfyanonymous/ComfyUI.git
Synced 2026-03-04 00:37:32 +08:00
feat: per-guide attention strength control in self-attention (#12518)
Implements per-guide attention attenuation via log-space additive bias in self-attention. Each guide reference tracks its own strength and optional spatial mask in conditioning metadata (guide_attention_entries).
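The "log-space additive bias" above relies on a standard identity: adding log(w_j) to the pre-softmax logit of key j rescales that key's post-softmax attention weight by w_j, up to renormalization. A minimal sketch of the identity, independent of the ComfyUI code (all names below are illustrative):

# Sketch of the identity (illustrative sizes): softmax(logits + log(w)) equals
# per-key weighting by w followed by renormalization.
import torch

torch.manual_seed(0)
logits = torch.randn(1, 4, 6)                       # (batch, queries, keys)
w = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.25, 0.25])  # attenuate the last two keys

biased = torch.softmax(logits + torch.log(w), dim=-1)
manual = torch.softmax(logits, dim=-1) * w
manual = manual / manual.sum(dim=-1, keepdim=True)

assert torch.allclose(biased, manual, atol=1e-5)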
parent 907e5dcbbf
commit a4522017c5
@@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
     def forward(
         self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
         v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
-        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
+        v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         run_vx = transformer_options.get("run_vx", True)
         run_ax = transformer_options.get("run_ax", True)

@@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
             vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
             norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
             del vshift_msa, vscale_msa
-            attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
+            attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
             del norm_vx
             # video cross-attention
             vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]

@@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel):
         return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
 
     def _process_transformer_blocks(
-        self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
+        self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
     ):
         vx = x[0]
         ax = x[1]

@@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel):
                     v_cross_gate_timestep=args["v_cross_gate_timestep"],
                     a_cross_gate_timestep=args["a_cross_gate_timestep"],
                     transformer_options=args["transformer_options"],
+                    self_attention_mask=args.get("self_attention_mask"),
                 )
                 return out
 
@@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel):
                     "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
                     "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
                     "transformer_options": transformer_options,
+                    "self_attention_mask": self_attention_mask,
                 },
                 {"original_block": block_wrap},
             )

@@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel):
                 v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
                 a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
                 transformer_options=transformer_options,
+                self_attention_mask=self_attention_mask,
             )
 
         return [vx, ax]
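A note on the block-replace changes above: the wrapper reads the new key with args.get("self_attention_mask"), so patches written before this change, which never set that key, keep working and simply run unmasked. A minimal sketch of that fallback, with hypothetical names standing in for the real block and patch:

# Minimal sketch (hypothetical names): forwarding an optional key through a
# block-replace wrapper. Patches created before this commit never set
# "self_attention_mask", so args.get(...) yields None and the block falls
# back to unmasked self-attention.
def block_wrap(args, original_block):
    return original_block(
        args["img"],
        self_attention_mask=args.get("self_attention_mask"),  # None if absent
    )

def original_block(img, self_attention_mask=None):
    return ("masked" if self_attention_mask is not None else "unmasked", img)

print(block_wrap({"img": 0}, original_block))
# -> ('unmasked', 0)
print(block_wrap({"img": 0, "self_attention_mask": object()}, original_block))
# -> ('masked', 0)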
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from enum import Enum
 import functools
+import logging
 import math
 from typing import Dict, Optional, Tuple
 
@@ -14,6 +15,8 @@ import comfy.ldm.common_dit
 
 from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
 
+logger = logging.getLogger(__name__)
+
 def _log_base(x, base):
     return np.log(x) / np.log(base)
 
@@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
 
         self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
 
-    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
+    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
         attn1_input = comfy.ldm.common_dit.rms_norm(x)
         attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
-        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+        attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
         x.addcmul_(attn1_input, gate_msa)
         del attn1_input
 
@@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
         """Process input data. Must be implemented by subclasses."""
         pass
 
+    def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
+        """Build self-attention mask for per-guide attention attenuation.
+
+        Base implementation returns None (no attenuation). Subclasses that
+        support guide-based attention control should override this.
+        """
+        return None
+
     @abstractmethod
-    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
+    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
         """Process transformer blocks. Must be implemented by subclasses."""
         pass
 
@@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
         attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
         pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
 
+        # Build self-attention mask for per-guide attenuation
+        self_attention_mask = self._build_guide_self_attention_mask(
+            x, transformer_options, merged_args
+        )
+
         # Process transformer blocks
         x = self._process_transformer_blocks(
-            x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
+            x, context, attention_mask, timestep, pe,
+            transformer_options=transformer_options,
+            self_attention_mask=self_attention_mask,
+            **merged_args,
         )
 
         # Process output
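The base-class hook added above is a template-method pattern: forward always asks for a guide mask, and the default hook returns None so models without guide support are untouched. A stripped-down sketch with hypothetical class names:

# Hedged sketch of the hook (hypothetical class names): the shared forward
# always asks for a guide mask; the base hook returns None so models without
# guide support pay nothing.
class Base:
    def _build_guide_self_attention_mask(self, x):
        return None  # default: no attenuation

    def forward(self, x):
        mask = self._build_guide_self_attention_mask(x)
        return f"attend({x}, mask={mask})"

class WithGuides(Base):
    def _build_guide_self_attention_mask(self, x):
        return "log-bias"  # a real model returns a bias tensor here

print(Base().forward("x"))        # attend(x, mask=None)
print(WithGuides().forward("x"))  # attend(x, mask=log-bias)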
@@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
             pixel_coords = pixel_coords[:, :, grid_mask, ...]
 
             kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
+
+            # Compute per-guide surviving token counts from guide_attention_entries.
+            # Each entry tracks one guide reference; they are appended in order and
+            # their pre_filter_counts partition the kf_grid_mask.
+            guide_entries = kwargs.get("guide_attention_entries", None)
+            if guide_entries:
+                total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
+                if total_pfc != len(kf_grid_mask):
+                    raise ValueError(
+                        f"guide pre_filter_counts ({total_pfc}) != "
+                        f"keyframe grid mask length ({len(kf_grid_mask)})"
+                    )
+                resolved_entries = []
+                offset = 0
+                for entry in guide_entries:
+                    pfc = entry["pre_filter_count"]
+                    entry_mask = kf_grid_mask[offset:offset + pfc]
+                    surviving = int(entry_mask.sum().item())
+                    resolved_entries.append({
+                        **entry,
+                        "surviving_count": surviving,
+                    })
+                    offset += pfc
+                additional_args["resolved_guide_entries"] = resolved_entries
+
             keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
             pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
+
+            # Total surviving guide tokens (all guides)
+            additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
 
         x = self.patchify_proj(x)
         return x, pixel_coords, additional_args
 
-    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
+    def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
+        """Build self-attention mask for per-guide attention attenuation.
+
+        Reads resolved_guide_entries from merged_args (computed in _process_input)
+        to build a log-space additive bias mask that attenuates noisy ↔ guide
+        attention for each guide reference independently.
+
+        Returns None if no attenuation is needed (all strengths == 1.0 and no
+        spatial masks, or no guide tokens).
+        """
+        if isinstance(x, list):
+            # AV model: x = [vx, ax]; use vx for token count and device
+            total_tokens = x[0].shape[1]
+            device = x[0].device
+            dtype = x[0].dtype
+        else:
+            total_tokens = x.shape[1]
+            device = x.device
+            dtype = x.dtype
+
+        num_guide_tokens = merged_args.get("num_guide_tokens", 0)
+        if num_guide_tokens == 0:
+            return None
+
+        resolved_entries = merged_args.get("resolved_guide_entries", None)
+        if not resolved_entries:
+            return None
+
+        # Check if any attenuation is actually needed
+        needs_attenuation = any(
+            e["strength"] < 1.0 or e.get("pixel_mask") is not None
+            for e in resolved_entries
+        )
+        if not needs_attenuation:
+            return None
+
+        # Build per-guide-token weights for all tracked guide tokens.
+        # Guides are appended in order at the end of the sequence.
+        guide_start = total_tokens - num_guide_tokens
+        all_weights = []
+        total_tracked = 0
+
+        for entry in resolved_entries:
+            surviving = entry["surviving_count"]
+            if surviving == 0:
+                continue
+
+            strength = entry["strength"]
+            pixel_mask = entry.get("pixel_mask")
+            latent_shape = entry.get("latent_shape")
+
+            if pixel_mask is not None and latent_shape is not None:
+                f_lat, h_lat, w_lat = latent_shape
+                per_token = self._downsample_mask_to_latent(
+                    pixel_mask.to(device=device, dtype=dtype),
+                    f_lat, h_lat, w_lat,
+                )
+                # per_token shape: (B, f_lat*h_lat*w_lat).
+                # Collapse batch dim — the mask is assumed identical across the
+                # batch; validate and take the first element to get (1, tokens).
+                if per_token.shape[0] > 1:
+                    ref = per_token[0]
+                    for bi in range(1, per_token.shape[0]):
+                        if not torch.equal(ref, per_token[bi]):
+                            logger.warning(
+                                "pixel_mask differs across batch elements; "
+                                "using first element only."
+                            )
+                            break
+                    per_token = per_token[:1]
+                # `surviving` is the post-grid_mask token count.
+                # Clamp to surviving to handle any mismatch safely.
+                n_weights = min(per_token.shape[1], surviving)
+                weights = per_token[:, :n_weights] * strength  # (1, n_weights)
+            else:
+                weights = torch.full(
+                    (1, surviving), strength, device=device, dtype=dtype
+                )
+
+            all_weights.append(weights)
+            total_tracked += weights.shape[1]
+
+        if not all_weights:
+            return None
+
+        # Concatenate per-token weights for all tracked guides
+        tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)
+
+        # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
+        if (tracked_weights >= 1.0).all():
+            return None
+
+        # Build the mask: guide tokens are at the end of the sequence.
+        # Tracked guides come first (in order), untracked follow.
+        return self._build_self_attention_mask(
+            total_tokens, num_guide_tokens, total_tracked,
+            tracked_weights, guide_start, device, dtype,
+        )
+
+    @staticmethod
+    def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
+        """Downsample a pixel-space mask to per-token latent weights.
+
+        Args:
+            mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
+            f_lat: Number of latent frames (pre-dilation original count).
+            h_lat: Latent height (pre-dilation original height).
+            w_lat: Latent width (pre-dilation original width).
+
+        Returns:
+            (B, F_lat * H_lat * W_lat) flattened per-token weights.
+        """
+        b = mask.shape[0]
+        f_pix = mask.shape[2]
+
+        # Spatial downsampling: area interpolation per frame
+        spatial_down = torch.nn.functional.interpolate(
+            rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
+            size=(h_lat, w_lat),
+            mode="area",
+        )
+        spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
+
+        # Temporal downsampling: first pixel frame maps to first latent frame,
+        # remaining pixel frames are averaged in groups for causal temporal structure.
+        first_frame = spatial_down[:, :, :1, :, :]
+        if f_pix > 1 and f_lat > 1:
+            remaining_pix = f_pix - 1
+            remaining_lat = f_lat - 1
+            t = remaining_pix // remaining_lat
+            if t < 1:
+                # Fewer pixel frames than latent frames — upsample by repeating
+                # the available pixel frames via nearest interpolation.
+                rest_flat = rearrange(
+                    spatial_down[:, :, 1:, :, :],
+                    "b 1 f h w -> (b h w) 1 f",
+                )
+                rest_up = torch.nn.functional.interpolate(
+                    rest_flat, size=remaining_lat, mode="nearest",
+                )
+                rest = rearrange(
+                    rest_up, "(b h w) 1 f -> b 1 f h w",
+                    b=b, h=h_lat, w=w_lat,
+                )
+            else:
+                # Trim trailing pixel frames that don't fill a complete group
+                usable = remaining_lat * t
+                rest = rearrange(
+                    spatial_down[:, :, 1:1 + usable, :, :],
+                    "b 1 (f t) h w -> b 1 f t h w",
+                    t=t,
+                )
+                rest = rest.mean(dim=3)
+            latent_mask = torch.cat([first_frame, rest], dim=2)
+        elif f_lat > 1:
+            # Single pixel frame but multiple latent frames — repeat the
+            # single frame across all latent frames.
+            latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
+        else:
+            latent_mask = first_frame
+
+        return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
+
+    @staticmethod
+    def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
+                                   tracked_weights, guide_start, device, dtype):
+        """Build a log-space additive self-attention bias mask.
+
+        Attenuates attention between noisy tokens and tracked guide tokens.
+        Untracked guide tokens (at the end of the guide portion) keep full attention.
+
+        Args:
+            total_tokens: Total sequence length.
+            num_guide_tokens: Total guide tokens (all guides) at end of sequence.
+            tracked_count: Number of tracked guide tokens (first in the guide portion).
+            tracked_weights: (1, tracked_count) tensor, values in [0, 1].
+            guide_start: Index where guide tokens begin in the sequence.
+            device: Target device.
+            dtype: Target dtype.
+
+        Returns:
+            (1, 1, total_tokens, total_tokens) additive bias mask.
+            0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
+        """
+        finfo = torch.finfo(dtype)
+        mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
+        tracked_end = guide_start + tracked_count
+
+        # Convert weights to log-space bias
+        w = tracked_weights.to(device=device, dtype=dtype)  # (1, tracked_count)
+        log_w = torch.full_like(w, finfo.min)
+        positive_mask = w > 0
+        if positive_mask.any():
+            log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
+
+        # noisy → tracked guides: each noisy row gets the same per-guide weight
+        mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
+        # tracked guides → noisy: each guide row broadcasts its weight across noisy cols
+        mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
+
+        return mask
+
+    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
         """Process transformer blocks for LTXV."""
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
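To make the band structure produced by _build_self_attention_mask concrete, here is a toy run with assumed sizes (6 tokens total, 3 guide tokens, of which 2 are tracked): the bias is zero everywhere except the noisy ↔ tracked-guide blocks, which hold log(w), with finfo.min standing in for log(0):

# Toy check (assumed sizes) of the band structure: zeros everywhere except
# the noisy<->tracked-guide blocks. Untracked guide tokens keep 0 bias.
import torch

total, n_guides, tracked = 6, 3, 2
guide_start = total - n_guides                      # 3
w = torch.tensor([[0.5, 0.0]])                      # per-tracked-guide weights
finfo = torch.finfo(torch.float32)

log_w = torch.full_like(w, finfo.min)
pos = w > 0
log_w[pos] = torch.log(w[pos].clamp(min=finfo.tiny))

mask = torch.zeros(1, 1, total, total)
end = guide_start + tracked
mask[:, :, :guide_start, guide_start:end] = log_w.view(1, 1, 1, -1)
mask[:, :, guide_start:end, :guide_start] = log_w.view(1, 1, -1, 1)

print(mask[0, 0, 0])  # noisy row: [0, 0, 0, log 0.5, finfo.min, 0]; last guide untracked
print(mask[0, 0, 5])  # untracked guide row: all zeros (full attention)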
@@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
 
             def block_wrap(args):
                 out = {}
-                out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
+                out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
                 return out
 
-            out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
+            out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
             x = out["img"]
         else:
             x = block(

@@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
                 timestep=timestep,
                 pe=pe,
                 transformer_options=transformer_options,
+                self_attention_mask=self_attention_mask,
             )
 
         return x
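The temporal reduction in _downsample_mask_to_latent keeps the first pixel frame as its own latent frame and averages the remaining f_pix - 1 frames in groups of t = (f_pix - 1) // (f_lat - 1), trimming any incomplete trailing group. A toy walk-through with assumed sizes (9 pixel frames down to 3 latent frames):

# Toy version of the causal temporal grouping (assumed sizes): frame 0 maps
# to latent frame 0 alone; the rest are averaged in groups of t.
import torch

f_pix, f_lat = 9, 3
x = torch.arange(f_pix, dtype=torch.float32)   # one value per pixel frame

first = x[:1]                                  # frame 0 -> latent frame 0
t = (f_pix - 1) // (f_lat - 1)                 # group size: 4
usable = (f_lat - 1) * t                       # trim incomplete trailing group
rest = x[1:1 + usable].view(f_lat - 1, t).mean(dim=1)

latent = torch.cat([first, rest])
print(latent)   # tensor([0.0000, 2.5000, 6.5000])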
@@ -65,6 +65,42 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
 
+
+class _CONDGuideEntries(comfy.conds.CONDConstant):
+    """CONDConstant subclass that safely compares guide_attention_entries.
+
+    guide_attention_entries may contain ``pixel_mask`` tensors. The default
+    ``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError``
+    on tensors. This subclass performs a structural comparison instead.
+    """
+
+    def can_concat(self, other):
+        if not isinstance(other, _CONDGuideEntries):
+            return False
+        a, b = self.cond, other.cond
+        if len(a) != len(b):
+            return False
+        for ea, eb in zip(a, b):
+            if ea["pre_filter_count"] != eb["pre_filter_count"]:
+                return False
+            if ea["strength"] != eb["strength"]:
+                return False
+            if ea.get("latent_shape") != eb.get("latent_shape"):
+                return False
+            a_has = ea.get("pixel_mask") is not None
+            b_has = eb.get("pixel_mask") is not None
+            if a_has != b_has:
+                return False
+            if a_has:
+                pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"]
+                if pm_a is not pm_b:
+                    if (pm_a.shape != pm_b.shape
+                            or pm_a.device != pm_b.device
+                            or pm_a.dtype != pm_b.dtype
+                            or not torch.equal(pm_a, pm_b)):
+                        return False
+        return True
+
+
 class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2
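The failure _CONDGuideEntries guards against is easy to reproduce in plain PyTorch: comparing dicts whose values are multi-element tensors forces bool() on an elementwise result and raises (a RuntimeError in stock PyTorch; the exact exception may differ along comfy's code path). Note also that Python's dict comparison short-circuits on identity, which is why can_concat checks pm_a is not pm_b before the more expensive torch.equal. Minimal repro:

# Why a structural can_concat is needed: dict inequality on entries holding
# multi-element tensors raises instead of returning a bool.
import torch

ea = {"strength": 1.0, "pixel_mask": torch.ones(2)}
eb = {"strength": 1.0, "pixel_mask": torch.ones(2)}

try:
    _ = ea != eb                      # bool() of a 2-element tensor comparison
except (RuntimeError, ValueError) as exc:
    print(type(exc).__name__, exc)

# The structural check sidesteps this entirely:
print(ea["strength"] == eb["strength"]
      and torch.equal(ea["pixel_mask"], eb["pixel_mask"]))  # True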
@@ -974,6 +1010,10 @@ class LTXV(BaseModel):
         if keyframe_idxs is not None:
             out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
 
+        guide_attention_entries = kwargs.get("guide_attention_entries", None)
+        if guide_attention_entries is not None:
+            out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
+
         return out
 
     def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):

@@ -1026,6 +1066,10 @@ class LTXAV(BaseModel):
         if latent_shapes is not None:
             out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
 
+        guide_attention_entries = kwargs.get("guide_attention_entries", None)
+        if guide_attention_entries is not None:
+            out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
+
         return out
 
     def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
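For reference, the guide_attention_entries payload that extra_conds wraps in _CONDGuideEntries is a list with one dict per guide, in the order the guides were added. A representative entry, with illustrative values:

# Shape of the conditioning metadata this commit adds (values illustrative):
guide_attention_entries = [
    {
        "pre_filter_count": 1 * 16 * 16,   # F * H * W latent positions before grid filtering
        "strength": 0.6,                   # attention weight toward/from this guide
        "pixel_mask": None,                # optional (B, 1, F_pix, H_pix, W_pix) tensor
        "latent_shape": [1, 16, 16],       # [F, H, W] used to downsample pixel_mask
    },
]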
@@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
     generate = execute  # TODO: remove
 
 
+def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
+    """Append a guide_attention_entry to both positive and negative conditioning.
+
+    Each entry tracks one guide reference for per-reference attention control.
+    Entries are derived independently from each conditioning to avoid cross-contamination.
+    """
+    new_entry = {
+        "pre_filter_count": pre_filter_count,
+        "strength": strength,
+        "pixel_mask": None,
+        "latent_shape": latent_shape,
+    }
+    results = []
+    for cond in (positive, negative):
+        # Read existing entries from this specific conditioning
+        existing = []
+        for t in cond:
+            found = t[1].get("guide_attention_entries", None)
+            if found is not None:
+                existing = found
+                break
+        # Shallow copy and append (no deepcopy needed — entries contain
+        # only scalars and None for pixel_mask at this call site).
+        entries = [*existing, new_entry]
+        results.append(node_helpers.conditioning_set_values(
+            cond, {"guide_attention_entries": entries}
+        ))
+    return results[0], results[1]
+
+
 def conditioning_get_any_value(conditioning, key, default=None):
     for t in conditioning:
         if key in t[1]:
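The helper above deliberately re-reads existing entries from each conditioning before appending, rather than sharing one list between positive and negative. A self-contained toy model of that accumulation logic; the list-of-[tensor, metadata] shape and the set_values stand-in only approximate ComfyUI's conditioning format:

# Hedged toy model of the append logic (stand-ins for ComfyUI's conditioning
# format and node_helpers.conditioning_set_values).
def set_values(cond, values):
    # copy each metadata dict and merge in the new values
    return [[t, {**meta, **values}] for t, meta in cond]

def append_entry(cond, new_entry):
    existing = next(
        (meta["guide_attention_entries"] for _, meta in cond
         if meta.get("guide_attention_entries") is not None),
        [],
    )
    return set_values(cond, {"guide_attention_entries": [*existing, new_entry]})

positive = [["emb", {}]]
positive = append_entry(positive, {"pre_filter_count": 256, "strength": 1.0})
positive = append_entry(positive, {"pre_filter_count": 128, "strength": 0.5})
print(len(positive[0][1]["guide_attention_entries"]))  # 2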
@@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
             scale_factors,
         )
 
+        # Track this guide for per-reference attention control.
+        pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
+        guide_latent_shape = list(t.shape[2:])  # [F, H, W]
+        positive, negative = _append_guide_attention_entry(
+            positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
+        )
+
         return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
 
     generate = execute  # TODO: remove
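The pre_filter_count recorded here (t.shape[2] * t.shape[3] * t.shape[4], i.e. F * H * W latent positions) is what later lets _process_input attribute surviving tokens back to each guide: the counts partition kf_grid_mask into contiguous slices whose popcounts become surviving_count. A toy version with assumed numbers:

# Toy resolution of surviving token counts (assumed numbers): each entry's
# pre_filter_count claims a contiguous slice of kf_grid_mask.
import torch

entries = [{"pre_filter_count": 4}, {"pre_filter_count": 6}]
kf_grid_mask = torch.tensor([1, 1, 0, 1, 0, 1, 1, 1, 0, 0], dtype=torch.bool)
assert sum(e["pre_filter_count"] for e in entries) == len(kf_grid_mask)

offset = 0
for e in entries:
    pfc = e["pre_filter_count"]
    e["surviving_count"] = int(kf_grid_mask[offset:offset + pfc].sum().item())
    offset += pfc

print([e["surviving_count"] for e in entries])  # [3, 3]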
@@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
             latent_image = latent_image[:, :, :-num_keyframes]
             noise_mask = noise_mask[:, :, :-num_keyframes]
 
-        positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
-        negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
+        positive = node_helpers.conditioning_set_values(positive, {
+            "keyframe_idxs": None,
+            "guide_attention_entries": None,
+        })
+        negative = node_helpers.conditioning_set_values(negative, {
+            "keyframe_idxs": None,
+            "guide_attention_entries": None,
+        })
 
         return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
 