Change auto model selection to best match from first match CORE-60.

This commit is contained in:
Talmaj Marinc 2026-05-01 12:06:06 +02:00
parent cf9cbec596
commit 827093ed1d
3 changed files with 60 additions and 17 deletions

View File

@ -905,9 +905,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None):
best_match = None
best_score = -1
for model_config in comfy.supported_models.models:
if model_config.matches(unet_config, state_dict):
return model_config(unet_config)
score = model_config.match_score(unet_config, state_dict)
if score > best_score:
best_score = score
best_match = model_config
if best_match is not None:
return best_match(unet_config)
logging.error("no match {}".format(unet_config))
return None

View File

@ -54,15 +54,31 @@ class BASE:
optimizations = {"fp8": False}
@classmethod
def matches(s, unet_config, state_dict=None):
def match_score(s, unet_config, state_dict=None):
"""Return a non-negative specificity score if this model matches the given
``unet_config``/``state_dict``, otherwise ``-1``.
The score is the total number of ``unet_config`` keys (plus ``required_keys``
when ``state_dict`` is provided) that this model declares and that match the
input. Higher scores indicate a more specific match, allowing
:func:`model_config_from_unet_config` to prefer the most specific model when
several configs would otherwise match by subset.
"""
score = 0
for k in s.unet_config:
if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False
return -1
score += 1
if state_dict is not None:
for k in s.required_keys:
if k not in state_dict:
return False
return True
return -1
score += 1
return score
@classmethod
def matches(s, unet_config, state_dict=None):
return s.match_score(unet_config, state_dict) >= 0
def model_type(self, state_dict, prefix=""):
return model_base.ModelType.EPS

View File

@ -60,18 +60,38 @@ def _make_flux_schnell_comfyui_sd():
class TestModelDetection:
"""Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity."""
"""Verify that model detection selects the most specific model regardless of
the ordering of entries in ``comfy.supported_models.models``."""
def test_longcat_before_schnell_in_models_list(self):
"""LongCatImage must appear before FluxSchnell in the models list."""
models = comfy.supported_models.models
longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage")
schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell")
assert longcat_idx < schnell_idx, (
f"LongCatImage (index {longcat_idx}) must come before "
f"FluxSchnell (index {schnell_idx}) in the models list"
)
def test_longcat_detection_is_order_independent(self):
"""Detection must pick LongCatImage over FluxSchnell regardless of
their relative order in the models list, because LongCatImage has a
strictly more specific ``unet_config``."""
original_models = comfy.supported_models.models
sd = _make_longcat_comfyui_sd()
unet_config = detect_unet_config(sd, "")
try:
for ordering in ("longcat_first", "schnell_first"):
models = list(original_models)
longcat = next(m for m in models if m.__name__ == "LongCatImage")
schnell = next(m for m in models if m.__name__ == "FluxSchnell")
models.remove(longcat)
models.remove(schnell)
if ordering == "longcat_first":
models.extend([longcat, schnell])
else:
models.extend([schnell, longcat])
comfy.supported_models.models = models
model_config = model_config_from_unet_config(unet_config, sd)
assert model_config is not None
assert type(model_config).__name__ == "LongCatImage", (
f"Expected LongCatImage with ordering={ordering}, "
f"got {type(model_config).__name__}"
)
finally:
comfy.supported_models.models = original_models
def test_longcat_comfyui_detected_as_longcat(self):
sd = _make_longcat_comfyui_sd()