mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
145 lines
5.5 KiB
Python
145 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
|
"""Convert SparkVSR/CogVideoX diffusers checkpoint to ComfyUI format.
|
|
|
|
Usage:
|
|
python convert_sparkvsr_to_comfy.py --model_dir path/to/sparkvsr-checkpoint \
|
|
--output_dir ComfyUI/models/
|
|
|
|
This creates two files:
|
|
- diffusion_models/cogvideox_sparkvsr.safetensors (transformer)
|
|
- vae/cogvideox_vae.safetensors (VAE)
|
|
|
|
T5-XXL text encoder does not need conversion — use existing ComfyUI T5 weights.
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
|
|
|
|
def remap_transformer_keys(state_dict):
    """Remap diffusers transformer keys to ComfyUI CogVideoX naming.

    Each key is rewritten by applying the ordered substring substitutions in
    ``rules`` below; like ``str.replace``, every occurrence of a pattern is
    replaced. Keys matched by no rule (e.g. ``patch_embed.*``, ``norm_final.*``,
    ``norm_out.*``) are copied through unchanged because their names already
    agree. Values are passed through untouched.

    Args:
        state_dict: mapping of diffusers parameter names to tensors (as
            returned by ``safetensors.torch.load_file``).

    Returns:
        A new dict with remapped keys; ``state_dict`` itself is not modified.

    Raises:
        ValueError: if two source keys map to the same target key, which
            would otherwise silently drop one of the tensors.
    """
    # Ordered (old, new) substring substitutions.
    rules = (
        # Time embedding: diffusers uses time_embedding.linear_1/2,
        # ComfyUI uses time_embedding_linear_1/2.
        ("time_embedding.linear_1.", "time_embedding_linear_1."),
        ("time_embedding.linear_2.", "time_embedding_linear_2."),
        # OFS embedding: same flattening as the time embedding.
        ("ofs_embedding.linear_1.", "ofs_embedding_linear_1."),
        ("ofs_embedding.linear_2.", "ofs_embedding_linear_2."),
        # Transformer blocks: diffusers uses transformer_blocks, ComfyUI uses blocks.
        ("transformer_blocks.", "blocks."),
        # Attention: diffusers uses attn1.to_q/k/v/out, ComfyUI uses q/k/v/attn_out.
        (".attn1.to_q.", ".q."),
        (".attn1.to_k.", ".k."),
        (".attn1.to_v.", ".v."),
        (".attn1.to_out.0.", ".attn_out."),
        (".attn1.norm_q.", ".norm_q."),
        (".attn1.norm_k.", ".norm_k."),
        # Feed-forward: diffusers uses ff.net.0.proj / ff.net.2,
        # ComfyUI uses ff_proj / ff_out.
        (".ff.net.0.proj.", ".ff_proj."),
        (".ff.net.2.", ".ff_out."),
    )

    new_sd = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in rules:
            new_key = new_key.replace(old, new)
        if new_key in new_sd:
            raise ValueError(
                f"Duplicate target key {new_key!r} produced while remapping {key!r}"
            )
        new_sd[new_key] = value
    return new_sd
|
|
|
|
|
|
def remap_vae_keys(state_dict):
    """Return the diffusers CogVideoX VAE state dict in ComfyUI naming.

    No keys are renamed: diffusers VAE names (``encoder.down_blocks.*``,
    ``encoder.mid_block.*``, ``decoder.up_blocks.*``, ``decoder.mid_block.*``,
    ``*.resnets.*``, ``*.downsamplers.0.*``, ``*.upsamplers.0.*`` and the
    ``.conv.weight`` inside CausalConv3d) are passed through as-is, on the
    assumption that ComfyUI's CogVideoX VAE is a direct port using the same
    names. NOTE(review): confirm against the ComfyUI VAE module; add rename
    rules here if any names diverge.

    Args:
        state_dict: mapping of diffusers VAE parameter names to tensors.

    Returns:
        A new dict with the same keys and values; ``state_dict`` itself is
        not modified.
    """
    return dict(state_dict)
|
|
|
|
|
|
def _load_safetensors_dir(directory, label):
    """Merge every ``*.safetensors`` file in *directory* into one state dict.

    Files are read in sorted filename order and later files win on key
    overlap. Overlap is reported because it usually means the directory holds
    several weight variants (e.g. an extra ``.fp16`` copy) that would be mixed.

    Raises:
        FileNotFoundError: if *directory* does not exist or contains no
            ``.safetensors`` files (rather than writing an empty checkpoint).
    """
    filenames = sorted(f for f in os.listdir(directory) if f.endswith(".safetensors"))
    if not filenames:
        raise FileNotFoundError(f"No .safetensors files found for {label} in {directory}")
    state_dict = {}
    for filename in filenames:
        shard = load_file(os.path.join(directory, filename))
        overlap = state_dict.keys() & shard.keys()
        if overlap:
            print(f"Warning: {filename} overrides {len(overlap)} {label} keys "
                  f"loaded from an earlier file in {directory}")
        state_dict.update(shard)
    return state_dict


def _convert_component(model_dir, component, remap, output_dir, subdir, filename, label):
    """Load ``model_dir/component``, remap its keys with *remap*, and save the
    result to ``output_dir/subdir/filename``.

    Keeping this in its own function means the (multi-GB) state dict is freed
    on return, before the next component is loaded.
    """
    src_dir = os.path.join(model_dir, component)
    print(f"Loading {label} from {src_dir}...")
    state_dict = remap(_load_safetensors_dir(src_dir, label))

    out_dir = os.path.join(output_dir, subdir)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, filename)
    print(f"Saving {label} to {out_path} ({len(state_dict)} keys)")
    save_file(state_dict, out_path)


def main():
    """Command-line entry point: convert the transformer and VAE of a
    diffusers SparkVSR/CogVideoX pipeline directory to ComfyUI files.

    Writes ``<output_dir>/diffusion_models/cogvideox_sparkvsr.safetensors``
    and ``<output_dir>/vae/cogvideox_vae.safetensors``. Exits with an error
    message if a component directory is missing or has no safetensors files.
    """
    parser = argparse.ArgumentParser(description="Convert SparkVSR/CogVideoX to ComfyUI format")
    parser.add_argument("--model_dir", type=str, required=True,
                        help="Path to diffusers pipeline directory (contains transformer/, vae/, etc.)")
    parser.add_argument("--output_dir", type=str, default=".",
                        help="Output base directory (will create diffusion_models/ and vae/ subdirs)")
    args = parser.parse_args()

    try:
        _convert_component(args.model_dir, "transformer", remap_transformer_keys,
                           args.output_dir, "diffusion_models",
                           "cogvideox_sparkvsr.safetensors", "transformer")
        _convert_component(args.model_dir, "vae", remap_vae_keys,
                           args.output_dir, "vae",
                           "cogvideox_vae.safetensors", "VAE")
    except FileNotFoundError as err:
        raise SystemExit(f"Error: {err}") from err

    print("Done! T5-XXL text encoder does not need conversion.")


if __name__ == "__main__":
    main()
|