mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 01:07:30 +08:00
fix: address review feedback for guide attention mask
Replace ValueError with logger.warning when pixel_mask differs across batch elements, and add assertion to validate pre_filter_count partitioning against kf_grid_mask length.
This commit is contained in:
parent
d6d8ebb601
commit
38e76f7b09
@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
@ -14,6 +15,8 @@ import comfy.ldm.common_dit
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _log_base(x, base):
|
||||
return np.log(x) / np.log(base)
|
||||
|
||||
@ -912,6 +915,11 @@ class LTXVModel(LTXBaseModel):
|
||||
# their pre_filter_counts partition the kf_grid_mask.
|
||||
guide_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_entries:
|
||||
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
|
||||
assert total_pfc == len(kf_grid_mask), (
|
||||
f"guide pre_filter_counts ({total_pfc}) != "
|
||||
f"keyframe grid mask length ({len(kf_grid_mask)})"
|
||||
)
|
||||
resolved_entries = []
|
||||
offset = 0
|
||||
for entry in guide_entries:
|
||||
@ -998,10 +1006,11 @@ class LTXVModel(LTXBaseModel):
|
||||
ref = per_token[0]
|
||||
for bi in range(1, per_token.shape[0]):
|
||||
if not torch.equal(ref, per_token[bi]):
|
||||
raise ValueError(
|
||||
logger.warning(
|
||||
"pixel_mask differs across batch elements; "
|
||||
"per-sample pixel masks are not supported."
|
||||
"using first element only."
|
||||
)
|
||||
break
|
||||
per_token = per_token[:1]
|
||||
# `surviving` is the post-grid_mask token count.
|
||||
# Clamp to surviving to handle any mismatch safely.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user