Simplify the logic by using repackaged weights.

This commit is contained in:
Talmaj Marinc 2026-02-24 23:13:32 +01:00
parent dfcd0ca160
commit adae3b4a48
4 changed files with 41 additions and 187 deletions

File diff suppressed because one or more lines are too long

View File

@ -279,36 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
return dit_config
if '{}x_embedder.weight'.format(key_prefix) in state_dict_keys and '{}transformer_blocks.0.attn.to_q.weight'.format(key_prefix) in state_dict_keys and '{}single_transformer_blocks.0.attn.to_q.weight'.format(key_prefix) in state_dict_keys and '{}context_embedder.weight'.format(key_prefix) in state_dict_keys and '{}time_embed.timestep_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys: #LongCat-Image (diffusers format, Flux variant)
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
dit_config["txt_ids_dims"] = [1, 2]
w = state_dict['{}x_embedder.weight'.format(key_prefix)]
dit_config["hidden_size"] = w.shape[0]
dit_config["in_channels"] = w.shape[1] // 4
dit_config["out_channels"] = dit_config["in_channels"]
dit_config["patch_size"] = 2
ctx_key = '{}context_embedder.weight'.format(key_prefix)
if ctx_key in state_dict_keys:
dit_config["context_in_dim"] = state_dict[ctx_key].shape[1]
else:
dit_config["context_in_dim"] = 3584
dit_config["vec_in_dim"] = None
dit_config["guidance_embed"] = False
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
dit_config["depth"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_transformer_blocks.'.format(key_prefix) + '{}.')
if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image
dit_config["txt_ids_dims"] = [1, 2]
return dit_config

View File

@ -1687,13 +1687,6 @@ class LongCatImage(supported_models_base.BASE):
"txt_ids_dims": [1, 2],
}
required_keys = {
"x_embedder.weight": None,
"context_embedder.weight": None,
"single_transformer_blocks.10.attn.to_q.weight": None,
"time_embed.timestep_embedder.linear_1.weight": None,
}
sampling_settings = {
}
@ -1707,115 +1700,6 @@ class LongCatImage(supported_models_base.BASE):
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def process_unet_state_dict(self, state_dict):
    """Rename Diffusers-style LongCat-Image keys to the Flux-style key layout.

    The conversion performs three kinds of work:

    * plain renames (``transformer_blocks.N.*`` -> ``double_blocks.N.*``,
      ``single_transformer_blocks.N.*`` -> ``single_blocks.N.*`` and a few
      top-level embedders / output layers);
    * fusion of separate projections into one tensor along dim 0:
      q/k/v -> ``img_attn.qkv`` / ``txt_attn.qkv`` for double blocks and
      q/k/v/proj_mlp -> ``linear1`` for single blocks;
    * swapping the two halves of ``norm_out.linear.*``.

    Args:
        state_dict: Mapping of parameter name to tensor. It is not mutated.

    Returns:
        A new dict with converted keys. Keys that match no rule are copied
        unchanged. Projection groups are only fused for the ``weight`` and
        ``bias`` suffixes and only when every member of the group is
        present; members of incomplete groups are not emitted.
    """
    # Sub-key prefix -> replacement prefix inside a double block. The final
    # dotted component of the original key (e.g. "weight") is appended.
    double_prefix_map = {
        "norm1.linear.": "img_mod.lin.",
        "norm1_context.linear.": "txt_mod.lin.",
        "attn.to_out.0.": "img_attn.proj.",
        "attn.to_add_out.": "txt_attn.proj.",
        "ff.net.0.proj.": "img_mlp.0.",
        "ff.net.2.": "img_mlp.2.",
        "ff_context.net.0.proj.": "txt_mlp.0.",
        "ff_context.net.2.": "txt_mlp.2.",
    }
    # Exact sub-key -> replacement sub-key inside a double block.
    double_exact_map = {
        "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
        "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
    }
    single_prefix_map = {
        "norm.linear.": "modulation.lin.",
        "proj_out.": "linear2.",
    }
    single_exact_map = {
        "attn.norm_q.weight": "norm.query_norm.weight",
        "attn.norm_k.weight": "norm.key_norm.weight",
    }
    top_exact_map = {
        "x_embedder.weight": "img_in.weight",
        "x_embedder.bias": "img_in.bias",
        "context_embedder.weight": "txt_in.weight",
        "context_embedder.bias": "txt_in.bias",
        "proj_out.weight": "final_layer.linear.weight",
        "proj_out.bias": "final_layer.linear.bias",
    }
    top_prefix_map = {
        "time_embed.timestep_embedder.linear_1.": "time_in.in_layer.",
        "time_embed.timestep_embedder.linear_2.": "time_in.out_layer.",
    }

    # Projections that must be concatenated later are parked here, keyed by
    # "<block index>.<suffix>".
    double_q, double_k, double_v = {}, {}, {}
    double_tq, double_tk, double_tv = {}, {}, {}
    single_q, single_k, single_v, single_mlp = {}, {}, {}, {}
    double_collect = {
        "attn.to_q.": double_q,
        "attn.to_k.": double_k,
        "attn.to_v.": double_v,
        "attn.add_q_proj.": double_tq,
        "attn.add_k_proj.": double_tk,
        "attn.add_v_proj.": double_tv,
    }
    single_collect = {
        "attn.to_q.": single_q,
        "attn.to_k.": single_k,
        "attn.to_v.": single_v,
        "proj_mlp.": single_mlp,
    }

    def lookup_prefix(name, table):
        # Return the table value whose key is a prefix of `name`, else None.
        for src, dst in table.items():
            if name.startswith(src):
                return dst
        return None

    def convert_block(idx, rest, leaf, value, out_prefix,
                      collect, exact_map, prefix_map):
        # Route one per-block tensor: park it for fusion, rename it, or
        # copy it under the new block prefix with its sub-key unchanged.
        bucket = lookup_prefix(rest, collect)
        if bucket is not None:  # an empty dict is a valid bucket
            bucket[idx + "." + leaf] = value
            return
        if rest in exact_map:
            out_sd[out_prefix + exact_map[rest]] = value
            return
        renamed = lookup_prefix(rest, prefix_map)
        if renamed is not None:
            out_sd[out_prefix + renamed + leaf] = value
        else:
            out_sd[out_prefix + rest] = value

    out_sd = {}
    for k, v in state_dict.items():
        parts = k.split(".")
        leaf = parts[-1]
        if k.startswith("transformer_blocks."):
            idx = parts[1]
            convert_block(idx, ".".join(parts[2:]), leaf, v,
                          "double_blocks.{}.".format(idx),
                          double_collect, double_exact_map, double_prefix_map)
        elif k.startswith("single_transformer_blocks."):
            idx = parts[1]
            convert_block(idx, ".".join(parts[2:]), leaf, v,
                          "single_blocks.{}.".format(idx),
                          single_collect, single_exact_map, single_prefix_map)
        elif k in top_exact_map:
            out_sd[top_exact_map[k]] = v
        elif k.startswith("norm_out.linear."):
            # HF AdaLayerNormContinuous stores [scale | shift] but ComfyUI
            # LastLayer expects [shift | scale], so swap the two halves.
            half = v.shape[0] // 2
            out_sd["final_layer.adaLN_modulation.1." + leaf] = torch.cat(
                [v[half:], v[:half]], dim=0)
        else:
            renamed = lookup_prefix(k, top_prefix_map)
            if renamed is not None:
                out_sd[renamed + leaf] = v
            else:
                out_sd[k] = v

    def fuse(groups, key):
        # Concatenate `key` from every group along dim 0, or return None
        # when any group lacks it.
        if all(key in group for group in groups):
            return torch.cat([group[key] for group in groups], dim=0)
        return None

    # Take block indices from both the image and the text query projections
    # so that text-stream fusion does not depend on an image-stream `to_q`
    # entry existing for the same block.
    double_ids = sorted({name.split(".")[0] for name in (*double_q, *double_tq)})
    single_ids = sorted({name.split(".")[0] for name in single_q})

    for suffix in ("weight", "bias"):
        for idx in double_ids:
            key = idx + "." + suffix
            fused = fuse((double_q, double_k, double_v), key)
            if fused is not None:
                out_sd["double_blocks.{}.img_attn.qkv.{}".format(idx, suffix)] = fused
            fused = fuse((double_tq, double_tk, double_tv), key)
            if fused is not None:
                out_sd["double_blocks.{}.txt_attn.qkv.{}".format(idx, suffix)] = fused
        for idx in single_ids:
            key = idx + "." + suffix
            fused = fuse((single_q, single_k, single_v, single_mlp), key)
            if fused is not None:
                out_sd["single_blocks.{}.linear1.{}".format(idx, suffix)] = fused
    return out_sd
def get_model(self, state_dict, prefix="", device=None):
    """Instantiate the LongCat-Image model wrapper for this configuration.

    ``state_dict`` and ``prefix`` are part of the signature but are not
    read by this method; only ``device`` is forwarded to the constructor.
    """
    return model_base.LongCatImage(self, device=device)

View File

@ -5,42 +5,40 @@ from comfy.model_detection import detect_unet_config, model_config_from_unet_con
import comfy.supported_models
def _make_longcat_diffusers_sd():
"""Minimal Diffusers-format state dict that triggers the LongCat-Image detection path."""
def _make_longcat_comfyui_sd():
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
sd = {}
H = 3072 # hidden_size (matches real LongCat-Image)
H = 3072
C_IN = 16
C_CTX = 3584 # context_in_dim that distinguishes LongCat from standard Flux (4096)
C_CTX = 3584
sd["x_embedder.weight"] = torch.empty(H, C_IN * 4)
sd["x_embedder.bias"] = torch.empty(H)
sd["context_embedder.weight"] = torch.empty(H, C_CTX)
sd["context_embedder.bias"] = torch.empty(H)
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
sd["img_in.bias"] = torch.empty(H)
sd["txt_in.weight"] = torch.empty(H, C_CTX)
sd["txt_in.bias"] = torch.empty(H)
sd["time_embed.timestep_embedder.linear_1.weight"] = torch.empty(H, 256)
sd["time_embed.timestep_embedder.linear_1.bias"] = torch.empty(H)
sd["time_embed.timestep_embedder.linear_2.weight"] = torch.empty(H, H)
sd["time_embed.timestep_embedder.linear_2.bias"] = torch.empty(H)
sd["time_in.in_layer.weight"] = torch.empty(H, 256)
sd["time_in.in_layer.bias"] = torch.empty(H)
sd["time_in.out_layer.weight"] = torch.empty(H, H)
sd["time_in.out_layer.bias"] = torch.empty(H)
sd["norm_out.linear.weight"] = torch.empty(2 * H, H)
sd["norm_out.linear.bias"] = torch.empty(2 * H)
sd["proj_out.weight"] = torch.empty(C_IN * 4, H)
sd["proj_out.bias"] = torch.empty(C_IN * 4)
sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H)
sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H)
sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H)
sd["final_layer.linear.bias"] = torch.empty(C_IN * 4)
# Need enough transformer_blocks and single_transformer_blocks for count_blocks
# and for the required_keys check (single_transformer_blocks.10.*)
for i in range(19):
sd[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.empty(H, H)
sd[f"transformer_blocks.{i}.norm1.linear.weight"] = torch.empty(H)
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H)
sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H)
for i in range(38):
sd[f"single_transformer_blocks.{i}.attn.to_q.weight"] = torch.empty(H, H)
sd[f"single_transformer_blocks.{i}.norm.linear.weight"] = torch.empty(H)
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
return sd
def _make_flux_schnell_comfyui_sd():
"""Minimal ComfyUI-format state dict that triggers the standard Flux detection path."""
"""Minimal ComfyUI-format state dict for standard Flux Schnell."""
sd = {}
H = 3072
C_IN = 16
@ -67,18 +65,16 @@ class TestModelDetectionSpecificity:
def test_longcat_wins_regardless_of_list_order(self):
"""Specificity logic must pick LongCatImage even when FluxSchnell appears first."""
sd = _make_longcat_diffusers_sd()
sd = _make_longcat_comfyui_sd()
unet_config = detect_unet_config(sd, "")
longcat_cls = comfy.supported_models.LongCatImage
schnell_cls = comfy.supported_models.FluxSchnell
# Order A: FluxSchnell before LongCatImage
order_a = [schnell_cls, longcat_cls]
# Order B: LongCatImage before FluxSchnell
order_b = [longcat_cls, schnell_cls]
for label, order in [("schnell-first", order_a), ("longcat-first", order_b)]:
for label, order in [
("schnell-first", [schnell_cls, longcat_cls]),
("longcat-first", [longcat_cls, schnell_cls]),
]:
with patch.object(comfy.supported_models, "models", order):
result = model_config_from_unet_config(unet_config, sd)
assert result is not None, f"No match with order {label}"
@ -86,29 +82,31 @@ class TestModelDetectionSpecificity:
f"Expected LongCatImage with order {label}, got {type(result).__name__}"
)
def test_longcat_diffusers_detected_as_longcat(self):
sd = _make_longcat_diffusers_sd()
def test_longcat_comfyui_detected_as_longcat(self):
sd = _make_longcat_comfyui_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "flux"
assert unet_config["context_in_dim"] == 3584
assert unet_config["vec_in_dim"] is None
assert unet_config["guidance_embed"] is False
assert unet_config["txt_ids_dims"] == [1, 2]
model_config = model_config_from_unet_config(unet_config, sd)
assert model_config is not None
assert type(model_config).__name__ == "LongCatImage"
def test_longcat_process_unet_state_dict_converts_keys(self):
sd = _make_longcat_diffusers_sd()
def test_longcat_comfyui_keys_pass_through_unchanged(self):
"""Pre-converted weights should not be transformed by process_unet_state_dict."""
sd = _make_longcat_comfyui_sd()
unet_config = detect_unet_config(sd, "")
model_config = model_config_from_unet_config(unet_config, sd)
converted = model_config.process_unet_state_dict(dict(sd))
assert "img_in.weight" in converted
assert "img_in.bias" in converted
assert "txt_in.weight" in converted
assert "x_embedder.weight" not in converted
assert "context_embedder.weight" not in converted
processed = model_config.process_unet_state_dict(dict(sd))
assert "img_in.weight" in processed
assert "txt_in.weight" in processed
assert "time_in.in_layer.weight" in processed
assert "final_layer.linear.weight" in processed
def test_flux_schnell_comfyui_detected_as_flux_schnell(self):
sd = _make_flux_schnell_comfyui_sd()