From 38e76f7b09d28d991f23334ed81ed6c602905423 Mon Sep 17 00:00:00 2001 From: tavi Date: Sat, 21 Feb 2026 23:15:51 +0000 Subject: [PATCH] fix: address review feedback for guide attention mask Replace ValueError with logger.warning when pixel_mask differs across batch elements, and add assertion to validate pre_filter_count partitioning against kf_grid_mask length. --- comfy/ldm/lightricks/model.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 87057f2e7..e1a541a3b 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from enum import Enum import functools +import logging import math from typing import Dict, Optional, Tuple @@ -14,6 +15,8 @@ import comfy.ldm.common_dit from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +logger = logging.getLogger(__name__) + def _log_base(x, base): return np.log(x) / np.log(base) @@ -912,6 +915,11 @@ class LTXVModel(LTXBaseModel): # their pre_filter_counts partition the kf_grid_mask. guide_entries = kwargs.get("guide_attention_entries", None) if guide_entries: + total_pfc = sum(e["pre_filter_count"] for e in guide_entries) + assert total_pfc == len(kf_grid_mask), ( + f"guide pre_filter_counts ({total_pfc}) != " + f"keyframe grid mask length ({len(kf_grid_mask)})" + ) resolved_entries = [] offset = 0 for entry in guide_entries: @@ -998,10 +1006,11 @@ class LTXVModel(LTXBaseModel): ref = per_token[0] for bi in range(1, per_token.shape[0]): if not torch.equal(ref, per_token[bi]): - raise ValueError( + logger.warning( "pixel_mask differs across batch elements; " - "per-sample pixel masks are not supported." + "using first element only." ) + break per_token = per_token[:1] # `surviving` is the post-grid_mask token count. # Clamp to surviving to handle any mismatch safely.