fix: address review feedback for guide attention mask

Replace ValueError with logger.warning when pixel_mask differs across
batch elements, and add assertion to validate pre_filter_count
partitioning against kf_grid_mask length.
This commit is contained in:
tavi 2026-02-21 23:15:51 +00:00 committed by tavihalperin
parent d6d8ebb601
commit 38e76f7b09

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import logging
import math
from typing import Dict, Optional, Tuple
@ -14,6 +15,8 @@ import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
logger = logging.getLogger(__name__)
def _log_base(x, base):
return np.log(x) / np.log(base)
@ -912,6 +915,11 @@ class LTXVModel(LTXBaseModel):
# their pre_filter_counts partition the kf_grid_mask.
guide_entries = kwargs.get("guide_attention_entries", None)
if guide_entries:
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
assert total_pfc == len(kf_grid_mask), (
f"guide pre_filter_counts ({total_pfc}) != "
f"keyframe grid mask length ({len(kf_grid_mask)})"
)
resolved_entries = []
offset = 0
for entry in guide_entries:
@ -998,10 +1006,11 @@ class LTXVModel(LTXBaseModel):
ref = per_token[0]
for bi in range(1, per_token.shape[0]):
if not torch.equal(ref, per_token[bi]):
raise ValueError(
logger.warning(
"pixel_mask differs across batch elements; "
"per-sample pixel masks are not supported."
"using first element only."
)
break
per_token = per_token[:1]
# `surviving` is the post-grid_mask token count.
# Clamp to surviving to handle any mismatch safely.