mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 01:37:45 +08:00
Cleanups to the last PR. (#12646)
This commit is contained in:
parent
a4522017c5
commit
8a4d85c708
@ -4,6 +4,25 @@ import comfy.utils
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def is_equal(x, y):
|
||||||
|
if torch.is_tensor(x) and torch.is_tensor(y):
|
||||||
|
return torch.equal(x, y)
|
||||||
|
elif isinstance(x, dict) and isinstance(y, dict):
|
||||||
|
if x.keys() != y.keys():
|
||||||
|
return False
|
||||||
|
return all(is_equal(x[k], y[k]) for k in x)
|
||||||
|
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||||
|
if type(x) is not type(y) or len(x) != len(y):
|
||||||
|
return False
|
||||||
|
return all(is_equal(a, b) for a, b in zip(x, y))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return x == y
|
||||||
|
except Exception:
|
||||||
|
logging.warning("comparison issue with COND")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class CONDRegular:
|
class CONDRegular:
|
||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
|
|||||||
return self._copy_with(self.cond)
|
return self._copy_with(self.cond)
|
||||||
|
|
||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
if self.cond != other.cond:
|
if not is_equal(self.cond, other.cond):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -65,42 +65,6 @@ from typing import TYPE_CHECKING
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
|
||||||
|
|
||||||
class _CONDGuideEntries(comfy.conds.CONDConstant):
|
|
||||||
"""CONDConstant subclass that safely compares guide_attention_entries.
|
|
||||||
|
|
||||||
guide_attention_entries may contain ``pixel_mask`` tensors. The default
|
|
||||||
``CONDConstant.can_concat`` uses ``!=`` which triggers a ``ValueError``
|
|
||||||
on tensors. This subclass performs a structural comparison instead.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def can_concat(self, other):
|
|
||||||
if not isinstance(other, _CONDGuideEntries):
|
|
||||||
return False
|
|
||||||
a, b = self.cond, other.cond
|
|
||||||
if len(a) != len(b):
|
|
||||||
return False
|
|
||||||
for ea, eb in zip(a, b):
|
|
||||||
if ea["pre_filter_count"] != eb["pre_filter_count"]:
|
|
||||||
return False
|
|
||||||
if ea["strength"] != eb["strength"]:
|
|
||||||
return False
|
|
||||||
if ea.get("latent_shape") != eb.get("latent_shape"):
|
|
||||||
return False
|
|
||||||
a_has = ea.get("pixel_mask") is not None
|
|
||||||
b_has = eb.get("pixel_mask") is not None
|
|
||||||
if a_has != b_has:
|
|
||||||
return False
|
|
||||||
if a_has:
|
|
||||||
pm_a, pm_b = ea["pixel_mask"], eb["pixel_mask"]
|
|
||||||
if pm_a is not pm_b:
|
|
||||||
if (pm_a.shape != pm_b.shape
|
|
||||||
or pm_a.device != pm_b.device
|
|
||||||
or pm_a.dtype != pm_b.dtype
|
|
||||||
or not torch.equal(pm_a, pm_b)):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
class ModelType(Enum):
|
class ModelType(Enum):
|
||||||
EPS = 1
|
EPS = 1
|
||||||
V_PREDICTION = 2
|
V_PREDICTION = 2
|
||||||
@ -1012,7 +976,7 @@ class LTXV(BaseModel):
|
|||||||
|
|
||||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||||
if guide_attention_entries is not None:
|
if guide_attention_entries is not None:
|
||||||
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -1068,7 +1032,7 @@ class LTXAV(BaseModel):
|
|||||||
|
|
||||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||||
if guide_attention_entries is not None:
|
if guide_attention_entries is not None:
|
||||||
out['guide_attention_entries'] = _CONDGuideEntries(guide_attention_entries)
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user