mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-18 04:57:26 +08:00
Merge 827093ed1d into 160b95f75c
This commit is contained in:
commit
058fb8508b
@ -905,9 +905,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
best_match = None
|
||||
best_score = -1
|
||||
for model_config in comfy.supported_models.models:
|
||||
if model_config.matches(unet_config, state_dict):
|
||||
return model_config(unet_config)
|
||||
score = model_config.match_score(unet_config, state_dict)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = model_config
|
||||
|
||||
if best_match is not None:
|
||||
return best_match(unet_config)
|
||||
|
||||
logging.error("no match {}".format(unet_config))
|
||||
return None
|
||||
|
||||
@ -54,15 +54,31 @@ class BASE:
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
def match_score(s, unet_config, state_dict=None):
|
||||
"""Return a non-negative specificity score if this model matches the given
|
||||
``unet_config``/``state_dict``, otherwise ``-1``.
|
||||
|
||||
The score is the total number of ``unet_config`` keys (plus ``required_keys``
|
||||
when ``state_dict`` is provided) that this model declares and that match the
|
||||
input. Higher scores indicate a more specific match, allowing
|
||||
:func:`model_config_from_unet_config` to prefer the most specific model when
|
||||
several configs would otherwise match by subset.
|
||||
"""
|
||||
score = 0
|
||||
for k in s.unet_config:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
return -1
|
||||
score += 1
|
||||
if state_dict is not None:
|
||||
for k in s.required_keys:
|
||||
if k not in state_dict:
|
||||
return False
|
||||
return True
|
||||
return -1
|
||||
score += 1
|
||||
return score
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
return s.match_score(unet_config, state_dict) >= 0
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
@ -60,18 +60,38 @@ def _make_flux_schnell_comfyui_sd():
|
||||
|
||||
|
||||
class TestModelDetection:
|
||||
"""Verify that first-match model detection selects the correct model
|
||||
based on list ordering and unet_config specificity."""
|
||||
"""Verify that model detection selects the most specific model regardless of
|
||||
the ordering of entries in ``comfy.supported_models.models``."""
|
||||
|
||||
def test_longcat_before_schnell_in_models_list(self):
|
||||
"""LongCatImage must appear before FluxSchnell in the models list."""
|
||||
models = comfy.supported_models.models
|
||||
longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage")
|
||||
schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell")
|
||||
assert longcat_idx < schnell_idx, (
|
||||
f"LongCatImage (index {longcat_idx}) must come before "
|
||||
f"FluxSchnell (index {schnell_idx}) in the models list"
|
||||
)
|
||||
def test_longcat_detection_is_order_independent(self):
|
||||
"""Detection must pick LongCatImage over FluxSchnell regardless of
|
||||
their relative order in the models list, because LongCatImage has a
|
||||
strictly more specific ``unet_config``."""
|
||||
original_models = comfy.supported_models.models
|
||||
sd = _make_longcat_comfyui_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
|
||||
try:
|
||||
for ordering in ("longcat_first", "schnell_first"):
|
||||
models = list(original_models)
|
||||
longcat = next(m for m in models if m.__name__ == "LongCatImage")
|
||||
schnell = next(m for m in models if m.__name__ == "FluxSchnell")
|
||||
models.remove(longcat)
|
||||
models.remove(schnell)
|
||||
if ordering == "longcat_first":
|
||||
models.extend([longcat, schnell])
|
||||
else:
|
||||
models.extend([schnell, longcat])
|
||||
comfy.supported_models.models = models
|
||||
|
||||
model_config = model_config_from_unet_config(unet_config, sd)
|
||||
assert model_config is not None
|
||||
assert type(model_config).__name__ == "LongCatImage", (
|
||||
f"Expected LongCatImage with ordering={ordering}, "
|
||||
f"got {type(model_config).__name__}"
|
||||
)
|
||||
finally:
|
||||
comfy.supported_models.models = original_models
|
||||
|
||||
def test_longcat_comfyui_detected_as_longcat(self):
|
||||
sd = _make_longcat_comfyui_sd()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user