From 827093ed1da9b5b52c15f08cd64f36a2843615a1 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 1 May 2026 12:06:06 +0200 Subject: [PATCH] Change auto model selection to best match from first match CORE-60. --- comfy/model_detection.py | 11 ++++- comfy/supported_models_base.py | 24 +++++++++-- tests-unit/comfy_test/model_detection_test.py | 42 ++++++++++++++----- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d9b67dcdf..d5e82b013 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -905,9 +905,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return unet_config def model_config_from_unet_config(unet_config, state_dict=None): + best_match = None + best_score = -1 for model_config in comfy.supported_models.models: - if model_config.matches(unet_config, state_dict): - return model_config(unet_config) + score = model_config.match_score(unet_config, state_dict) + if score > best_score: + best_score = score + best_match = model_config + + if best_match is not None: + return best_match(unet_config) logging.error("no match {}".format(unet_config)) return None diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 0e7a829ba..653d6a0f9 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -54,15 +54,31 @@ class BASE: optimizations = {"fp8": False} @classmethod - def matches(s, unet_config, state_dict=None): + def match_score(s, unet_config, state_dict=None): + """Return a non-negative specificity score if this model matches the given + ``unet_config``/``state_dict``, otherwise ``-1``. + + The score is the total number of ``unet_config`` keys (plus ``required_keys`` + when ``state_dict`` is provided) that this model declares and that match the + input. Higher scores indicate a more specific match, allowing + :func:`model_config_from_unet_config` to prefer the most specific model when + several configs would otherwise match by subset. + """ + score = 0 for k in s.unet_config: if k not in unet_config or s.unet_config[k] != unet_config[k]: - return False + return -1 + score += 1 if state_dict is not None: for k in s.required_keys: if k not in state_dict: - return False - return True + return -1 + score += 1 + return score + + @classmethod + def matches(s, unet_config, state_dict=None): + return s.match_score(unet_config, state_dict) >= 0 def model_type(self, state_dict, prefix=""): return model_base.ModelType.EPS diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 2551a417b..dfbe68620 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -60,18 +60,38 @@ def _make_flux_schnell_comfyui_sd(): class TestModelDetection: - """Verify that first-match model detection selects the correct model - based on list ordering and unet_config specificity.""" + """Verify that model detection selects the most specific model regardless of + the ordering of entries in ``comfy.supported_models.models``.""" - def test_longcat_before_schnell_in_models_list(self): - """LongCatImage must appear before FluxSchnell in the models list.""" - models = comfy.supported_models.models - longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage") - schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell") - assert longcat_idx < schnell_idx, ( - f"LongCatImage (index {longcat_idx}) must come before " - f"FluxSchnell (index {schnell_idx}) in the models list" - ) + def test_longcat_detection_is_order_independent(self): + """Detection must pick LongCatImage over FluxSchnell regardless of + their relative order in the models list, because LongCatImage has a + strictly more specific ``unet_config``.""" + original_models = comfy.supported_models.models + sd = _make_longcat_comfyui_sd() + unet_config = detect_unet_config(sd, "") + + try: + for ordering in ("longcat_first", "schnell_first"): + models = list(original_models) + longcat = next(m for m in models if m.__name__ == "LongCatImage") + schnell = next(m for m in models if m.__name__ == "FluxSchnell") + models.remove(longcat) + models.remove(schnell) + if ordering == "longcat_first": + models.extend([longcat, schnell]) + else: + models.extend([schnell, longcat]) + comfy.supported_models.models = models + + model_config = model_config_from_unet_config(unet_config, sd) + assert model_config is not None + assert type(model_config).__name__ == "LongCatImage", ( + f"Expected LongCatImage with ordering={ordering}, " + f"got {type(model_config).__name__}" + ) + finally: + comfy.supported_models.models = original_models def test_longcat_comfyui_detected_as_longcat(self): sd = _make_longcat_comfyui_sd()