ComfyUI/convert_sparkvsr_to_comfy.py
2026-04-10 20:28:05 +02:00

145 lines
5.5 KiB
Python

#!/usr/bin/env python3
"""Convert SparkVSR/CogVideoX diffusers checkpoint to ComfyUI format.
Usage:
python convert_sparkvsr_to_comfy.py --model_dir path/to/sparkvsr-checkpoint \
--output_dir ComfyUI/models/
This creates two files:
- diffusion_models/cogvideox_sparkvsr.safetensors (transformer)
- vae/cogvideox_vae.safetensors (VAE)
T5-XXL text encoder does not need conversion — use existing ComfyUI T5 weights.
"""
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
def remap_transformer_keys(state_dict):
"""Remap diffusers transformer keys to ComfyUI CogVideoX naming."""
new_sd = {}
for k, v in state_dict.items():
new_key = k
# Patch embedding
new_key = new_key.replace("patch_embed.proj.", "patch_embed.proj.")
new_key = new_key.replace("patch_embed.text_proj.", "patch_embed.text_proj.")
new_key = new_key.replace("patch_embed.pos_embedding", "patch_embed.pos_embedding")
# Time embedding: diffusers uses time_embedding.linear_1/2, we use time_embedding_linear_1/2
new_key = new_key.replace("time_embedding.linear_1.", "time_embedding_linear_1.")
new_key = new_key.replace("time_embedding.linear_2.", "time_embedding_linear_2.")
# OFS embedding
new_key = new_key.replace("ofs_embedding.linear_1.", "ofs_embedding_linear_1.")
new_key = new_key.replace("ofs_embedding.linear_2.", "ofs_embedding_linear_2.")
# Transformer blocks: diffusers uses transformer_blocks, we use blocks
new_key = new_key.replace("transformer_blocks.", "blocks.")
# Attention: diffusers uses attn1.to_q/k/v/out, we use q/k/v/attn_out
new_key = new_key.replace(".attn1.to_q.", ".q.")
new_key = new_key.replace(".attn1.to_k.", ".k.")
new_key = new_key.replace(".attn1.to_v.", ".v.")
new_key = new_key.replace(".attn1.to_out.0.", ".attn_out.")
new_key = new_key.replace(".attn1.norm_q.", ".norm_q.")
new_key = new_key.replace(".attn1.norm_k.", ".norm_k.")
# Feed-forward: diffusers uses ff.net.0.proj/ff.net.2, we use ff_proj/ff_out
new_key = new_key.replace(".ff.net.0.proj.", ".ff_proj.")
new_key = new_key.replace(".ff.net.2.", ".ff_out.")
# Output norms
new_key = new_key.replace("norm_final.", "norm_final.")
new_key = new_key.replace("norm_out.linear.", "norm_out.linear.")
new_key = new_key.replace("norm_out.norm.", "norm_out.norm.")
new_sd[new_key] = v
return new_sd
def remap_vae_keys(state_dict):
"""Remap diffusers VAE keys to ComfyUI CogVideoX naming.
The VAE architecture is directly ported so most keys should match.
Main differences are in block naming.
"""
new_sd = {}
for k, v in state_dict.items():
new_key = k
# Encoder blocks
new_key = new_key.replace("encoder.down_blocks.", "encoder.down_blocks.")
new_key = new_key.replace("encoder.mid_block.", "encoder.mid_block.")
# Decoder blocks
new_key = new_key.replace("decoder.up_blocks.", "decoder.up_blocks.")
new_key = new_key.replace("decoder.mid_block.", "decoder.mid_block.")
# Resnet blocks within down/up/mid
new_key = new_key.replace(".resnets.", ".resnets.")
# CausalConv3d: diffusers stores as .conv.weight inside CausalConv3d
# Our CausalConv3d also has .conv.weight, so this should match
# Downsamplers/Upsamplers
new_key = new_key.replace(".downsamplers.0.", ".downsamplers.0.")
new_key = new_key.replace(".upsamplers.0.", ".upsamplers.0.")
new_sd[new_key] = v
return new_sd
def main():
parser = argparse.ArgumentParser(description="Convert SparkVSR/CogVideoX to ComfyUI format")
parser.add_argument("--model_dir", type=str, required=True,
help="Path to diffusers pipeline directory (contains transformer/, vae/, etc.)")
parser.add_argument("--output_dir", type=str, default=".",
help="Output base directory (will create diffusion_models/ and vae/ subdirs)")
args = parser.parse_args()
# Load transformer
transformer_dir = os.path.join(args.model_dir, "transformer")
print(f"Loading transformer from {transformer_dir}...")
transformer_sd = {}
for f in sorted(os.listdir(transformer_dir)):
if f.endswith(".safetensors"):
sd = load_file(os.path.join(transformer_dir, f))
transformer_sd.update(sd)
transformer_sd = remap_transformer_keys(transformer_sd)
out_dir = os.path.join(args.output_dir, "diffusion_models")
os.makedirs(out_dir, exist_ok=True)
out_path = os.path.join(out_dir, "cogvideox_sparkvsr.safetensors")
print(f"Saving transformer to {out_path} ({len(transformer_sd)} keys)")
save_file(transformer_sd, out_path)
# Load VAE
vae_dir = os.path.join(args.model_dir, "vae")
print(f"Loading VAE from {vae_dir}...")
vae_sd = {}
for f in sorted(os.listdir(vae_dir)):
if f.endswith(".safetensors"):
sd = load_file(os.path.join(vae_dir, f))
vae_sd.update(sd)
vae_sd = remap_vae_keys(vae_sd)
out_dir = os.path.join(args.output_dir, "vae")
os.makedirs(out_dir, exist_ok=True)
out_path = os.path.join(out_dir, "cogvideox_vae.safetensors")
print(f"Saving VAE to {out_path} ({len(vae_sd)} keys)")
save_file(vae_sd, out_path)
print("Done! T5-XXL text encoder does not need conversion.")
if __name__ == "__main__":
main()