mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-20 22:17:31 +08:00
Merge 827093ed1d into 6bcd8b96ab
This commit is contained in:
commit
7130b8a3ce
@ -905,9 +905,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
return unet_config
|
return unet_config
|
||||||
|
|
||||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||||
|
best_match = None
|
||||||
|
best_score = -1
|
||||||
for model_config in comfy.supported_models.models:
|
for model_config in comfy.supported_models.models:
|
||||||
if model_config.matches(unet_config, state_dict):
|
score = model_config.match_score(unet_config, state_dict)
|
||||||
return model_config(unet_config)
|
if score > best_score:
|
||||||
|
best_score = score
|
||||||
|
best_match = model_config
|
||||||
|
|
||||||
|
if best_match is not None:
|
||||||
|
return best_match(unet_config)
|
||||||
|
|
||||||
logging.error("no match {}".format(unet_config))
|
logging.error("no match {}".format(unet_config))
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -54,15 +54,31 @@ class BASE:
|
|||||||
optimizations = {"fp8": False}
|
optimizations = {"fp8": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def matches(s, unet_config, state_dict=None):
|
def match_score(s, unet_config, state_dict=None):
|
||||||
|
"""Return a non-negative specificity score if this model matches the given
|
||||||
|
``unet_config``/``state_dict``, otherwise ``-1``.
|
||||||
|
|
||||||
|
The score is the total number of ``unet_config`` keys (plus ``required_keys``
|
||||||
|
when ``state_dict`` is provided) that this model declares and that match the
|
||||||
|
input. Higher scores indicate a more specific match, allowing
|
||||||
|
:func:`model_config_from_unet_config` to prefer the most specific model when
|
||||||
|
several configs would otherwise match by subset.
|
||||||
|
"""
|
||||||
|
score = 0
|
||||||
for k in s.unet_config:
|
for k in s.unet_config:
|
||||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||||
return False
|
return -1
|
||||||
|
score += 1
|
||||||
if state_dict is not None:
|
if state_dict is not None:
|
||||||
for k in s.required_keys:
|
for k in s.required_keys:
|
||||||
if k not in state_dict:
|
if k not in state_dict:
|
||||||
return False
|
return -1
|
||||||
return True
|
score += 1
|
||||||
|
return score
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def matches(s, unet_config, state_dict=None):
|
||||||
|
return s.match_score(unet_config, state_dict) >= 0
|
||||||
|
|
||||||
def model_type(self, state_dict, prefix=""):
|
def model_type(self, state_dict, prefix=""):
|
||||||
return model_base.ModelType.EPS
|
return model_base.ModelType.EPS
|
||||||
|
|||||||
@ -60,18 +60,38 @@ def _make_flux_schnell_comfyui_sd():
|
|||||||
|
|
||||||
|
|
||||||
class TestModelDetection:
|
class TestModelDetection:
|
||||||
"""Verify that first-match model detection selects the correct model
|
"""Verify that model detection selects the most specific model regardless of
|
||||||
based on list ordering and unet_config specificity."""
|
the ordering of entries in ``comfy.supported_models.models``."""
|
||||||
|
|
||||||
def test_longcat_before_schnell_in_models_list(self):
|
def test_longcat_detection_is_order_independent(self):
|
||||||
"""LongCatImage must appear before FluxSchnell in the models list."""
|
"""Detection must pick LongCatImage over FluxSchnell regardless of
|
||||||
models = comfy.supported_models.models
|
their relative order in the models list, because LongCatImage has a
|
||||||
longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage")
|
strictly more specific ``unet_config``."""
|
||||||
schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell")
|
original_models = comfy.supported_models.models
|
||||||
assert longcat_idx < schnell_idx, (
|
sd = _make_longcat_comfyui_sd()
|
||||||
f"LongCatImage (index {longcat_idx}) must come before "
|
unet_config = detect_unet_config(sd, "")
|
||||||
f"FluxSchnell (index {schnell_idx}) in the models list"
|
|
||||||
)
|
try:
|
||||||
|
for ordering in ("longcat_first", "schnell_first"):
|
||||||
|
models = list(original_models)
|
||||||
|
longcat = next(m for m in models if m.__name__ == "LongCatImage")
|
||||||
|
schnell = next(m for m in models if m.__name__ == "FluxSchnell")
|
||||||
|
models.remove(longcat)
|
||||||
|
models.remove(schnell)
|
||||||
|
if ordering == "longcat_first":
|
||||||
|
models.extend([longcat, schnell])
|
||||||
|
else:
|
||||||
|
models.extend([schnell, longcat])
|
||||||
|
comfy.supported_models.models = models
|
||||||
|
|
||||||
|
model_config = model_config_from_unet_config(unet_config, sd)
|
||||||
|
assert model_config is not None
|
||||||
|
assert type(model_config).__name__ == "LongCatImage", (
|
||||||
|
f"Expected LongCatImage with ordering={ordering}, "
|
||||||
|
f"got {type(model_config).__name__}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
comfy.supported_models.models = original_models
|
||||||
|
|
||||||
def test_longcat_comfyui_detected_as_longcat(self):
|
def test_longcat_comfyui_detected_as_longcat(self):
|
||||||
sd = _make_longcat_comfyui_sd()
|
sd = _make_longcat_comfyui_sd()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user