mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 01:07:30 +08:00
Simplify the logic by using repackaged weights.
This commit is contained in:
parent
dfcd0ca160
commit
adae3b4a48
File diff suppressed because one or more lines are too long
@ -279,36 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}x_embedder.weight'.format(key_prefix) in state_dict_keys and '{}transformer_blocks.0.attn.to_q.weight'.format(key_prefix) in state_dict_keys and '{}single_transformer_blocks.0.attn.to_q.weight'.format(key_prefix) in state_dict_keys and '{}context_embedder.weight'.format(key_prefix) in state_dict_keys and '{}time_embed.timestep_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys: #LongCat-Image (diffusers format, Flux variant)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
w = state_dict['{}x_embedder.weight'.format(key_prefix)]
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
dit_config["in_channels"] = w.shape[1] // 4
|
||||
dit_config["out_channels"] = dit_config["in_channels"]
|
||||
dit_config["patch_size"] = 2
|
||||
|
||||
ctx_key = '{}context_embedder.weight'.format(key_prefix)
|
||||
if ctx_key in state_dict_keys:
|
||||
dit_config["context_in_dim"] = state_dict[ctx_key].shape[1]
|
||||
else:
|
||||
dit_config["context_in_dim"] = 3584
|
||||
|
||||
dit_config["vec_in_dim"] = None
|
||||
dit_config["guidance_embed"] = False
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
return dit_config
|
||||
|
||||
|
||||
@ -1687,13 +1687,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
"txt_ids_dims": [1, 2],
|
||||
}
|
||||
|
||||
required_keys = {
|
||||
"x_embedder.weight": None,
|
||||
"context_embedder.weight": None,
|
||||
"single_transformer_blocks.10.attn.to_q.weight": None,
|
||||
"time_embed.timestep_embedder.linear_1.weight": None,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
}
|
||||
|
||||
@ -1707,115 +1700,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
double_q, double_k, double_v = {}, {}, {}
|
||||
double_tq, double_tk, double_tv = {}, {}, {}
|
||||
single_q, single_k, single_v, single_mlp = {}, {}, {}, {}
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("transformer_blocks."):
|
||||
idx = k.split(".")[1]
|
||||
rest = ".".join(k.split(".")[2:])
|
||||
prefix = "double_blocks.{}.".format(idx)
|
||||
|
||||
if rest.startswith("norm1.linear."):
|
||||
out_sd[prefix + "img_mod.lin." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("norm1_context.linear."):
|
||||
out_sd[prefix + "txt_mod.lin." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.to_q."):
|
||||
double_q[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.to_k."):
|
||||
double_k[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.to_v."):
|
||||
double_v[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest == "attn.norm_q.weight":
|
||||
out_sd[prefix + "img_attn.norm.query_norm.weight"] = v
|
||||
elif rest == "attn.norm_k.weight":
|
||||
out_sd[prefix + "img_attn.norm.key_norm.weight"] = v
|
||||
elif rest.startswith("attn.to_out.0."):
|
||||
out_sd[prefix + "img_attn.proj." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.add_q_proj."):
|
||||
double_tq[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.add_k_proj."):
|
||||
double_tk[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.add_v_proj."):
|
||||
double_tv[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest == "attn.norm_added_q.weight":
|
||||
out_sd[prefix + "txt_attn.norm.query_norm.weight"] = v
|
||||
elif rest == "attn.norm_added_k.weight":
|
||||
out_sd[prefix + "txt_attn.norm.key_norm.weight"] = v
|
||||
elif rest.startswith("attn.to_add_out."):
|
||||
out_sd[prefix + "txt_attn.proj." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("ff.net.0.proj."):
|
||||
out_sd[prefix + "img_mlp.0." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("ff.net.2."):
|
||||
out_sd[prefix + "img_mlp.2." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("ff_context.net.0.proj."):
|
||||
out_sd[prefix + "txt_mlp.0." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("ff_context.net.2."):
|
||||
out_sd[prefix + "txt_mlp.2." + rest.split(".")[-1]] = v
|
||||
else:
|
||||
out_sd["double_blocks.{}.{}".format(idx, rest)] = v
|
||||
|
||||
elif k.startswith("single_transformer_blocks."):
|
||||
idx = k.split(".")[1]
|
||||
rest = ".".join(k.split(".")[2:])
|
||||
prefix = "single_blocks.{}.".format(idx)
|
||||
|
||||
if rest.startswith("norm.linear."):
|
||||
out_sd[prefix + "modulation.lin." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.to_q."):
|
||||
single_q[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.to_k."):
|
||||
single_k[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("attn.to_v."):
|
||||
single_v[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest == "attn.norm_q.weight":
|
||||
out_sd[prefix + "norm.query_norm.weight"] = v
|
||||
elif rest == "attn.norm_k.weight":
|
||||
out_sd[prefix + "norm.key_norm.weight"] = v
|
||||
elif rest.startswith("proj_mlp."):
|
||||
single_mlp[idx + "." + rest.split(".")[-1]] = v
|
||||
elif rest.startswith("proj_out."):
|
||||
out_sd[prefix + "linear2." + rest.split(".")[-1]] = v
|
||||
else:
|
||||
out_sd["single_blocks.{}.{}".format(idx, rest)] = v
|
||||
|
||||
elif k == "x_embedder.weight" or k == "x_embedder.bias":
|
||||
out_sd["img_in." + k.split(".")[-1]] = v
|
||||
elif k == "context_embedder.weight" or k == "context_embedder.bias":
|
||||
out_sd["txt_in." + k.split(".")[-1]] = v
|
||||
elif k.startswith("time_embed.timestep_embedder.linear_1."):
|
||||
out_sd["time_in.in_layer." + k.split(".")[-1]] = v
|
||||
elif k.startswith("time_embed.timestep_embedder.linear_2."):
|
||||
out_sd["time_in.out_layer." + k.split(".")[-1]] = v
|
||||
elif k.startswith("norm_out.linear."):
|
||||
# HF AdaLayerNormContinuous stores [scale | shift] but ComfyUI
|
||||
# LastLayer expects [shift | scale], so swap the two halves.
|
||||
half = v.shape[0] // 2
|
||||
v = torch.cat([v[half:], v[:half]], dim=0)
|
||||
out_sd["final_layer.adaLN_modulation.1." + k.split(".")[-1]] = v
|
||||
elif k == "proj_out.weight" or k == "proj_out.bias":
|
||||
out_sd["final_layer.linear." + k.split(".")[-1]] = v
|
||||
else:
|
||||
out_sd[k] = v
|
||||
|
||||
for suffix in ["weight", "bias"]:
|
||||
for idx in sorted(set(x.split(".")[0] for x in double_q)):
|
||||
qk = idx + "." + suffix
|
||||
if qk in double_q and qk in double_k and qk in double_v:
|
||||
out_sd["double_blocks.{}.img_attn.qkv.{}".format(idx, suffix)] = torch.cat([double_q[qk], double_k[qk], double_v[qk]], dim=0)
|
||||
if qk in double_tq and qk in double_tk and qk in double_tv:
|
||||
out_sd["double_blocks.{}.txt_attn.qkv.{}".format(idx, suffix)] = torch.cat([double_tq[qk], double_tk[qk], double_tv[qk]], dim=0)
|
||||
|
||||
for idx in sorted(set(x.split(".")[0] for x in single_q)):
|
||||
qk = idx + "." + suffix
|
||||
if qk in single_q and qk in single_k and qk in single_v and qk in single_mlp:
|
||||
out_sd["single_blocks.{}.linear1.{}".format(idx, suffix)] = torch.cat([single_q[qk], single_k[qk], single_v[qk], single_mlp[qk]], dim=0)
|
||||
|
||||
return out_sd
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.LongCatImage(self, device=device)
|
||||
return out
|
||||
|
||||
@ -5,42 +5,40 @@ from comfy.model_detection import detect_unet_config, model_config_from_unet_con
|
||||
import comfy.supported_models
|
||||
|
||||
|
||||
def _make_longcat_diffusers_sd():
|
||||
"""Minimal Diffusers-format state dict that triggers the LongCat-Image detection path."""
|
||||
def _make_longcat_comfyui_sd():
|
||||
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
|
||||
sd = {}
|
||||
H = 3072 # hidden_size (matches real LongCat-Image)
|
||||
H = 3072
|
||||
C_IN = 16
|
||||
C_CTX = 3584 # context_in_dim that distinguishes LongCat from standard Flux (4096)
|
||||
C_CTX = 3584
|
||||
|
||||
sd["x_embedder.weight"] = torch.empty(H, C_IN * 4)
|
||||
sd["x_embedder.bias"] = torch.empty(H)
|
||||
sd["context_embedder.weight"] = torch.empty(H, C_CTX)
|
||||
sd["context_embedder.bias"] = torch.empty(H)
|
||||
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
|
||||
sd["img_in.bias"] = torch.empty(H)
|
||||
sd["txt_in.weight"] = torch.empty(H, C_CTX)
|
||||
sd["txt_in.bias"] = torch.empty(H)
|
||||
|
||||
sd["time_embed.timestep_embedder.linear_1.weight"] = torch.empty(H, 256)
|
||||
sd["time_embed.timestep_embedder.linear_1.bias"] = torch.empty(H)
|
||||
sd["time_embed.timestep_embedder.linear_2.weight"] = torch.empty(H, H)
|
||||
sd["time_embed.timestep_embedder.linear_2.bias"] = torch.empty(H)
|
||||
sd["time_in.in_layer.weight"] = torch.empty(H, 256)
|
||||
sd["time_in.in_layer.bias"] = torch.empty(H)
|
||||
sd["time_in.out_layer.weight"] = torch.empty(H, H)
|
||||
sd["time_in.out_layer.bias"] = torch.empty(H)
|
||||
|
||||
sd["norm_out.linear.weight"] = torch.empty(2 * H, H)
|
||||
sd["norm_out.linear.bias"] = torch.empty(2 * H)
|
||||
sd["proj_out.weight"] = torch.empty(C_IN * 4, H)
|
||||
sd["proj_out.bias"] = torch.empty(C_IN * 4)
|
||||
sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H)
|
||||
sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H)
|
||||
sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H)
|
||||
sd["final_layer.linear.bias"] = torch.empty(C_IN * 4)
|
||||
|
||||
# Need enough transformer_blocks and single_transformer_blocks for count_blocks
|
||||
# and for the required_keys check (single_transformer_blocks.10.*)
|
||||
for i in range(19):
|
||||
sd[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.empty(H, H)
|
||||
sd[f"transformer_blocks.{i}.norm1.linear.weight"] = torch.empty(H)
|
||||
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
|
||||
sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H)
|
||||
sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H)
|
||||
for i in range(38):
|
||||
sd[f"single_transformer_blocks.{i}.attn.to_q.weight"] = torch.empty(H, H)
|
||||
sd[f"single_transformer_blocks.{i}.norm.linear.weight"] = torch.empty(H)
|
||||
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
|
||||
|
||||
return sd
|
||||
|
||||
|
||||
def _make_flux_schnell_comfyui_sd():
|
||||
"""Minimal ComfyUI-format state dict that triggers the standard Flux detection path."""
|
||||
"""Minimal ComfyUI-format state dict for standard Flux Schnell."""
|
||||
sd = {}
|
||||
H = 3072
|
||||
C_IN = 16
|
||||
@ -67,18 +65,16 @@ class TestModelDetectionSpecificity:
|
||||
|
||||
def test_longcat_wins_regardless_of_list_order(self):
|
||||
"""Specificity logic must pick LongCatImage even when FluxSchnell appears first."""
|
||||
sd = _make_longcat_diffusers_sd()
|
||||
sd = _make_longcat_comfyui_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
|
||||
longcat_cls = comfy.supported_models.LongCatImage
|
||||
schnell_cls = comfy.supported_models.FluxSchnell
|
||||
|
||||
# Order A: FluxSchnell before LongCatImage
|
||||
order_a = [schnell_cls, longcat_cls]
|
||||
# Order B: LongCatImage before FluxSchnell
|
||||
order_b = [longcat_cls, schnell_cls]
|
||||
|
||||
for label, order in [("schnell-first", order_a), ("longcat-first", order_b)]:
|
||||
for label, order in [
|
||||
("schnell-first", [schnell_cls, longcat_cls]),
|
||||
("longcat-first", [longcat_cls, schnell_cls]),
|
||||
]:
|
||||
with patch.object(comfy.supported_models, "models", order):
|
||||
result = model_config_from_unet_config(unet_config, sd)
|
||||
assert result is not None, f"No match with order {label}"
|
||||
@ -86,29 +82,31 @@ class TestModelDetectionSpecificity:
|
||||
f"Expected LongCatImage with order {label}, got {type(result).__name__}"
|
||||
)
|
||||
|
||||
def test_longcat_diffusers_detected_as_longcat(self):
|
||||
sd = _make_longcat_diffusers_sd()
|
||||
def test_longcat_comfyui_detected_as_longcat(self):
|
||||
sd = _make_longcat_comfyui_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
assert unet_config is not None
|
||||
assert unet_config["image_model"] == "flux"
|
||||
assert unet_config["context_in_dim"] == 3584
|
||||
assert unet_config["vec_in_dim"] is None
|
||||
assert unet_config["guidance_embed"] is False
|
||||
assert unet_config["txt_ids_dims"] == [1, 2]
|
||||
|
||||
model_config = model_config_from_unet_config(unet_config, sd)
|
||||
assert model_config is not None
|
||||
assert type(model_config).__name__ == "LongCatImage"
|
||||
|
||||
def test_longcat_process_unet_state_dict_converts_keys(self):
|
||||
sd = _make_longcat_diffusers_sd()
|
||||
def test_longcat_comfyui_keys_pass_through_unchanged(self):
|
||||
"""Pre-converted weights should not be transformed by process_unet_state_dict."""
|
||||
sd = _make_longcat_comfyui_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
model_config = model_config_from_unet_config(unet_config, sd)
|
||||
|
||||
converted = model_config.process_unet_state_dict(dict(sd))
|
||||
assert "img_in.weight" in converted
|
||||
assert "img_in.bias" in converted
|
||||
assert "txt_in.weight" in converted
|
||||
assert "x_embedder.weight" not in converted
|
||||
assert "context_embedder.weight" not in converted
|
||||
processed = model_config.process_unet_state_dict(dict(sd))
|
||||
assert "img_in.weight" in processed
|
||||
assert "txt_in.weight" in processed
|
||||
assert "time_in.in_layer.weight" in processed
|
||||
assert "final_layer.linear.weight" in processed
|
||||
|
||||
def test_flux_schnell_comfyui_detected_as_flux_schnell(self):
|
||||
sd = _make_flux_schnell_comfyui_sd()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user