feat: per-guide attention strength control in self-attention (#12518)

Implements per-guide attention attenuation via log-space additive bias
in self-attention. Each guide reference tracks its own strength and
optional spatial mask in conditioning metadata (guide_attention_entries).
This commit is contained in:
Tavi Halperin 2026-02-26 08:25:23 +02:00 committed by GitHub
parent 907e5dcbbf
commit a4522017c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 352 additions and 12 deletions

View File

@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
def forward( def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True) run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True) run_ax = transformer_options.get("run_ax", True)
@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
del norm_vx del norm_vx
# video cross-attention # video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel):
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks( def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
): ):
vx = x[0] vx = x[0]
ax = x[1] ax = x[1]
@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=args["v_cross_gate_timestep"], v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"], a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"], transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
) )
return out return out
@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel):
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options, "transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
}, },
{"original_block": block_wrap}, {"original_block": block_wrap},
) )
@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options, transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
) )
return [vx, ax] return [vx, ax]

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
import functools import functools
import logging
import math import math
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@ -14,6 +15,8 @@ import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
logger = logging.getLogger(__name__)
def _log_base(x, base): def _log_base(x, base):
return np.log(x) / np.log(base) return np.log(x) / np.log(base)
@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
attn1_input = comfy.ldm.common_dit.rms_norm(x) attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa) x.addcmul_(attn1_input, gate_msa)
del attn1_input del attn1_input
@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
"""Process input data. Must be implemented by subclasses.""" """Process input data. Must be implemented by subclasses."""
pass pass
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Base implementation returns None (no attenuation). Subclasses that
support guide-based attention control should override this.
"""
return None
@abstractmethod @abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses.""" """Process transformer blocks. Must be implemented by subclasses."""
pass pass
@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Build self-attention mask for per-guide attenuation
self_attention_mask = self._build_guide_self_attention_mask(
x, transformer_options, merged_args
)
# Process transformer blocks # Process transformer blocks
x = self._process_transformer_blocks( x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args x, context, attention_mask, timestep, pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
**merged_args,
) )
# Process output # Process output
@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
pixel_coords = pixel_coords[:, :, grid_mask, ...] pixel_coords = pixel_coords[:, :, grid_mask, ...]
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
# Compute per-guide surviving token counts from guide_attention_entries.
# Each entry tracks one guide reference; they are appended in order and
# their pre_filter_counts partition the kf_grid_mask.
guide_entries = kwargs.get("guide_attention_entries", None)
if guide_entries:
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
if total_pfc != len(kf_grid_mask):
raise ValueError(
f"guide pre_filter_counts ({total_pfc}) != "
f"keyframe grid mask length ({len(kf_grid_mask)})"
)
resolved_entries = []
offset = 0
for entry in guide_entries:
pfc = entry["pre_filter_count"]
entry_mask = kf_grid_mask[offset:offset + pfc]
surviving = int(entry_mask.sum().item())
resolved_entries.append({
**entry,
"surviving_count": surviving,
})
offset += pfc
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
x = self.patchify_proj(x) x = self.patchify_proj(x)
return x, pixel_coords, additional_args return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Reads resolved_guide_entries from merged_args (computed in _process_input)
to build a log-space additive bias mask that attenuates noisy guide
attention for each guide reference independently.
Returns None if no attenuation is needed (all strengths == 1.0 and no
spatial masks, or no guide tokens).
"""
if isinstance(x, list):
# AV model: x = [vx, ax]; use vx for token count and device
total_tokens = x[0].shape[1]
device = x[0].device
dtype = x[0].dtype
else:
total_tokens = x.shape[1]
device = x.device
dtype = x.dtype
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
if num_guide_tokens == 0:
return None
resolved_entries = merged_args.get("resolved_guide_entries", None)
if not resolved_entries:
return None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_attenuation:
return None
# Build per-guide-token weights for all tracked guide tokens.
# Guides are appended in order at the end of the sequence.
guide_start = total_tokens - num_guide_tokens
all_weights = []
total_tracked = 0
for entry in resolved_entries:
surviving = entry["surviving_count"]
if surviving == 0:
continue
strength = entry["strength"]
pixel_mask = entry.get("pixel_mask")
latent_shape = entry.get("latent_shape")
if pixel_mask is not None and latent_shape is not None:
f_lat, h_lat, w_lat = latent_shape
per_token = self._downsample_mask_to_latent(
pixel_mask.to(device=device, dtype=dtype),
f_lat, h_lat, w_lat,
)
# per_token shape: (B, f_lat*h_lat*w_lat).
# Collapse batch dim — the mask is assumed identical across the
# batch; validate and take the first element to get (1, tokens).
if per_token.shape[0] > 1:
ref = per_token[0]
for bi in range(1, per_token.shape[0]):
if not torch.equal(ref, per_token[bi]):
logger.warning(
"pixel_mask differs across batch elements; "
"using first element only."
)
break
per_token = per_token[:1]
# `surviving` is the post-grid_mask token count.
# Clamp to surviving to handle any mismatch safely.
n_weights = min(per_token.shape[1], surviving)
weights = per_token[:, :n_weights] * strength # (1, n_weights)
else:
weights = torch.full(
(1, surviving), strength, device=device, dtype=dtype
)
all_weights.append(weights)
total_tracked += weights.shape[1]
if not all_weights:
return None
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
return None
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
    """Reduce a pixel-space mask to one weight per latent token.

    Args:
        mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask, values in [0, 1].
        f_lat: target latent frame count (pre-dilation original count).
        h_lat: target latent height (pre-dilation original height).
        w_lat: target latent width (pre-dilation original width).

    Returns:
        (B, f_lat * h_lat * w_lat) flattened per-token weights.
    """
    batch, frames_pix = mask.shape[0], mask.shape[2]

    # Area-interpolate every frame down to the latent spatial grid.
    per_frame = torch.nn.functional.interpolate(
        rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
        size=(h_lat, w_lat),
        mode="area",
    )
    per_frame = rearrange(per_frame, "(b f) 1 h w -> b 1 f h w", b=batch)

    # Causal temporal reduction: the first pixel frame maps 1:1 to the
    # first latent frame; the remaining frames are pooled in equal groups.
    lead = per_frame[:, :, :1, :, :]
    if f_lat <= 1:
        pooled = lead
    elif frames_pix <= 1:
        # One pixel frame but several latent frames: broadcast it.
        pooled = lead.expand(-1, -1, f_lat, -1, -1)
    else:
        tail_pix = frames_pix - 1
        tail_lat = f_lat - 1
        group = tail_pix // tail_lat
        if group < 1:
            # Fewer pixel frames than latent frames: nearest-upsample
            # the tail along time.
            flat_tail = rearrange(
                per_frame[:, :, 1:, :, :],
                "b 1 f h w -> (b h w) 1 f",
            )
            flat_tail = torch.nn.functional.interpolate(
                flat_tail, size=tail_lat, mode="nearest",
            )
            tail = rearrange(
                flat_tail, "(b h w) 1 f -> b 1 f h w",
                b=batch, h=h_lat, w=w_lat,
            )
        else:
            # Drop trailing pixel frames that do not fill a whole group.
            full = tail_lat * group
            tail = rearrange(
                per_frame[:, :, 1:1 + full, :, :],
                "b 1 (f t) h w -> b 1 f t h w",
                t=group,
            )
            tail = tail.mean(dim=3)
        pooled = torch.cat([lead, tail], dim=2)

    return rearrange(pooled, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV.""" """Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block( x = block(
@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
timestep=timestep, timestep=timestep,
pe=pe, pe=pe,
transformer_options=transformer_options, transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
) )
return x return x

View File

@ -65,6 +65,42 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher from comfy.model_patcher import ModelPatcher
class _CONDGuideEntries(comfy.conds.CONDConstant):
    """CONDConstant that knows how to compare guide_attention_entries.

    Entries may carry ``pixel_mask`` tensors, and the stock
    ``CONDConstant.can_concat`` compares conds with ``!=``, which raises a
    ``ValueError`` on tensors. Compare the entries field-by-field instead.
    """

    def can_concat(self, other):
        if not isinstance(other, _CONDGuideEntries):
            return False
        mine, theirs = self.cond, other.cond
        if len(mine) != len(theirs):
            return False
        for lhs, rhs in zip(mine, theirs):
            if lhs["pre_filter_count"] != rhs["pre_filter_count"]:
                return False
            if lhs["strength"] != rhs["strength"]:
                return False
            if lhs.get("latent_shape") != rhs.get("latent_shape"):
                return False
            lhs_mask = lhs.get("pixel_mask")
            rhs_mask = rhs.get("pixel_mask")
            if (lhs_mask is None) != (rhs_mask is None):
                return False
            if lhs_mask is None or lhs_mask is rhs_mask:
                continue
            # Structural tensor comparison: metadata first, then values.
            if lhs_mask.shape != rhs_mask.shape:
                return False
            if lhs_mask.device != rhs_mask.device:
                return False
            if lhs_mask.dtype != rhs_mask.dtype:
                return False
            if not torch.equal(lhs_mask, rhs_mask):
                return False
        return True
class ModelType(Enum): class ModelType(Enum):
EPS = 1 EPS = 1
V_PREDICTION = 2 V_PREDICTION = 2
@ -974,6 +1010,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None: if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@ -1026,6 +1066,10 @@ class LTXAV(BaseModel):
if latent_shapes is not None: if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):

View File

@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
    """Record one guide reference on both the positive and negative conditioning.

    Every guide gets its own guide_attention_entries item so attention can
    later be controlled per reference. The existing entry list is read from
    each conditioning independently, so the two never share state.
    """
    entry = {
        "pre_filter_count": pre_filter_count,
        "strength": strength,
        "pixel_mask": None,
        "latent_shape": latent_shape,
    }

    def _with_entry(cond):
        # First conditioning tuple carrying a non-None entry list wins;
        # default to an empty history.
        prior = next(
            (t[1]["guide_attention_entries"]
             for t in cond
             if t[1].get("guide_attention_entries") is not None),
            [],
        )
        # Shallow copy suffices: entries hold only scalars and a None
        # pixel_mask at this call site.
        return node_helpers.conditioning_set_values(
            cond, {"guide_attention_entries": [*prior, entry]}
        )

    return _with_entry(positive), _with_entry(negative)
def conditioning_get_any_value(conditioning, key, default=None): def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning: for t in conditioning:
if key in t[1]: if key in t[1]:
@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
scale_factors, scale_factors,
) )
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove generate = execute # TODO: remove
@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent_image[:, :, :-num_keyframes] latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) positive = node_helpers.conditioning_set_values(positive, {
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) "keyframe_idxs": None,
"guide_attention_entries": None,
})
negative = node_helpers.conditioning_set_values(negative, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})