ComfyUI/comfy/ldm/moge/state_dict.py
2026-05-12 16:09:24 +03:00

95 lines
3.5 KiB
Python

"""Translate MoGe checkpoint keys to the layouts our nn.Modules use.
MoGe checkpoints embed DINOv2 with the original Meta naming
(``backbone.blocks.{i}.attn.qkv.weight``, ``ls1.gamma``, ``mlp.w12``, ...).
The shared ``comfy.image_encoders.dino2.Dinov2Model`` uses HF naming
(``encoder.layer.{i}.attention.attention.{query,key,value}.weight``,
``layer_scale1.lambda1``, ``mlp.weights_in``, ...). We rewrite keys at load
time and split the fused ``qkv`` weight into separate Q/K/V tensors.
"""
from __future__ import annotations
import re
# Exact-match renames for backbone keys that live outside the transformer
# blocks. Keys are Meta-style names relative to the backbone prefix; values are
# the comfy/HF-style names they are rewritten to.
_DINOV2_TOPLEVEL_RENAMES = {
    # Patch embedding projection.
    "patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
    "patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
    # Learned tokens and position embedding.
    "cls_token": "embeddings.cls_token",
    "pos_embed": "embeddings.position_embeddings",
    "register_tokens": "embeddings.register_tokens",
    "mask_token": "embeddings.mask_token",
    # Final layer norm.
    "norm.weight": "layernorm.weight",
    "norm.bias": "layernorm.bias",
}
# (Meta-style fragment, comfy/HF-style fragment) pairs for keys inside a
# transformer block. remap_dinov2_keys applies them in order with str.replace
# to the part of the key that follows ``blocks.{i}.``.
_BLOCK_SUFFIX_RENAMES = [
    # Layer-scale parameters.
    ("ls1.gamma", "layer_scale1.lambda1"),
    ("ls2.gamma", "layer_scale2.lambda1"),
    # Attention output projection (trailing dot keeps weight/bias intact).
    ("attn.proj.", "attention.output.dense."),
    # MLP layers named w12 / w3 in the Meta checkpoints.
    ("mlp.w12.", "mlp.weights_in."),
    ("mlp.w3.", "mlp.weights_out."),
]
# Matches a per-block key relative to the backbone prefix, e.g.
# ``blocks.3.attn.qkv.weight``: group 1 is the block index (digits only),
# group 2 is everything after ``blocks.{i}.``.
_BLOCK_RE = re.compile(r"^blocks\.(\d+)\.(.+)$")
def remap_dinov2_keys(sd: dict, src_prefix: str = "") -> dict:
    """Return a new dict with Meta-style DINOv2 keys renamed to comfy/HF style.

    Only keys that begin with ``src_prefix`` are candidates for renaming, and
    the prefix is kept in front of every rewritten key. For those keys:

    - names listed in ``_DINOV2_TOPLEVEL_RENAMES`` are replaced outright;
    - ``blocks.{i}.attn.qkv.{weight,bias}`` is split into three pieces with
      ``chunk(3, dim=0)`` and stored as
      ``encoder.layer.{i}.attention.attention.{query,key,value}.{weight,bias}``;
    - any other ``blocks.{i}.<rest>`` key becomes ``encoder.layer.{i}.<rest>``
      after the ``_BLOCK_SUFFIX_RENAMES`` replacements are applied to ``<rest>``.

    Every other key (outside the prefix, or under it but unrecognised) is
    copied through with its original name. ``sd`` itself is not modified.
    """
    prefix_len = len(src_prefix)
    remapped: dict = {}

    for key, tensor in sd.items():
        # Outside the backbone subtree: pass through untouched.
        if not key.startswith(src_prefix):
            remapped[key] = tensor
            continue

        relative = key[prefix_len:]

        # Non-block entries (tokens, embeddings, final norm) are exact matches.
        top_level = _DINOV2_TOPLEVEL_RENAMES.get(relative)
        if top_level is not None:
            remapped[src_prefix + top_level] = tensor
            continue

        block = _BLOCK_RE.match(relative)
        if block is None:
            # Under the prefix but not something we know how to rename.
            remapped[key] = tensor
            continue

        index, suffix = block.groups()
        layer = f"{src_prefix}encoder.layer.{index}"

        # The fused qkv parameter becomes three separate entries.
        if suffix in ("attn.qkv.weight", "attn.qkv.bias"):
            query, key_part, value = tensor.chunk(3, dim=0)
            kind = suffix.rpartition(".")[2]  # "weight" or "bias"
            attention = f"{layer}.attention.attention"
            remapped[f"{attention}.query.{kind}"] = query
            remapped[f"{attention}.key.{kind}"] = key_part
            remapped[f"{attention}.value.{kind}"] = value
            continue

        # Remaining block keys: apply each fragment rename in table order.
        for meta_name, hf_name in _BLOCK_SUFFIX_RENAMES:
            suffix = suffix.replace(meta_name, hf_name)
        remapped[f"{layer}.{suffix}"] = tensor

    return remapped
def remap_moge_state_dict(sd: dict) -> dict:
    """Rewrite the DINOv2 backbone keys of a full MoGe checkpoint state dict.

    The backbone prefix is chosen from the keys present in ``sd``:

    - if any key starts with ``encoder.backbone.`` (v2 layout), that subtree
      is rewritten;
    - otherwise the ``backbone.`` subtree (v1 layout) is rewritten.

    Keys outside the chosen subtree (heads, neck, projections,
    image_mean/std buffers) keep their original names.
    """
    v2_prefix = "encoder.backbone."
    is_v2 = any(key.startswith(v2_prefix) for key in sd)
    backbone_prefix = v2_prefix if is_v2 else "backbone."
    return remap_dinov2_keys(sd, src_prefix=backbone_prefix)