mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Remove sparkvsr related code.
This commit is contained in:
parent
bf2c582605
commit
a16fc7ee98
@ -1,137 +0,0 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
from typing_extensions import override
|
||||
|
||||
class SparkVSRConditioning(io.ComfyNode):
|
||||
"""Conditioning node for SparkVSR video super-resolution.
|
||||
|
||||
Encodes LQ video and optional HR reference frames through the VAE,
|
||||
builds the concat conditioning for the CogVideoX I2V model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SparkVSRConditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("lq_video"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
io.Image.Input("ref_frames", optional=True),
|
||||
io.Combo.Input("ref_mode", options=["auto", "manual"], default="auto", optional=True),
|
||||
io.String.Input("ref_indices", default="", optional=True),
|
||||
io.Float.Input("ref_guidance_scale", default=1.0, min=0.0, max=10.0, step=0.1, optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, lq_video, width, height, length,
|
||||
batch_size, ref_frames=None, ref_mode="auto", ref_indices="",
|
||||
ref_guidance_scale=1.0) -> io.NodeOutput:
|
||||
|
||||
temporal_compression = 4
|
||||
latent_t = ((length - 1) // temporal_compression) + 1
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
|
||||
# Base latent (noise will be added by KSampler)
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, latent_t, latent_h, latent_w],
|
||||
device=comfy.model_management.intermediate_device()
|
||||
)
|
||||
|
||||
# Encode LQ video → this becomes the base latent (KSampler adds noise to this)
|
||||
lq = lq_video[:length]
|
||||
lq = comfy.utils.common_upscale(lq.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
lq_latent = vae.encode(lq[:, :, :, :3])
|
||||
|
||||
# Ensure temporal dim matches
|
||||
if lq_latent.shape[2] > latent_t:
|
||||
lq_latent = lq_latent[:, :, :latent_t]
|
||||
elif lq_latent.shape[2] < latent_t:
|
||||
pad = latent_t - lq_latent.shape[2]
|
||||
lq_latent = torch.cat([lq_latent, lq_latent[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2)
|
||||
|
||||
# Build reference latent (16ch) — goes as concat_latent_image
|
||||
# concat_cond in model_base will concatenate this with the noise (16ch) → 32ch total
|
||||
ref_latent = torch.zeros_like(lq_latent)
|
||||
|
||||
if ref_frames is not None:
|
||||
num_video_frames = lq_video.shape[0]
|
||||
|
||||
# Determine reference indices
|
||||
if ref_mode == "manual" and ref_indices.strip():
|
||||
indices = [int(x.strip()) for x in ref_indices.split(",") if x.strip()]
|
||||
else:
|
||||
indices = _select_indices(num_video_frames)
|
||||
|
||||
# Encode each reference frame and place at its temporal position.
|
||||
# SparkVSR places refs at specific latent indices, rest stays zeros.
|
||||
for ref_idx in indices:
|
||||
if ref_idx >= ref_frames.shape[0]:
|
||||
continue
|
||||
|
||||
frame = ref_frames[ref_idx:ref_idx + 1]
|
||||
frame = comfy.utils.common_upscale(frame.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
frame_latent = vae.encode(frame[:, :, :, :3])
|
||||
|
||||
target_lat_idx = ref_idx // temporal_compression
|
||||
if target_lat_idx < latent_t:
|
||||
ref_latent[:, :, target_lat_idx] = frame_latent[:, :, 0]
|
||||
|
||||
# Set ref latent as concat conditioning (16ch, model_base.concat_cond adds it to noise)
|
||||
if ref_guidance_scale != 1.0 and ref_frames is not None:
|
||||
# CFG: positive gets real refs, negative gets zero refs
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": torch.zeros_like(ref_latent)})
|
||||
else:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": ref_latent})
|
||||
|
||||
# LQ latent is the base — KSampler will noise it and denoise
|
||||
out_latent = {"samples": lq_latent}
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
def _select_indices(num_frames, max_refs=None):
|
||||
"""Auto-select reference frame indices (first, evenly spaced, last)."""
|
||||
if max_refs is None:
|
||||
max_refs = (num_frames - 1) // 4 + 1
|
||||
max_refs = min(max_refs, 3)
|
||||
|
||||
if num_frames <= 1:
|
||||
return [0]
|
||||
if max_refs == 1:
|
||||
return [0]
|
||||
if max_refs == 2:
|
||||
return [0, num_frames - 1]
|
||||
|
||||
mid = num_frames // 2
|
||||
return [0, mid, num_frames - 1]
|
||||
|
||||
|
||||
class CogVideoXExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SparkVSRConditioning,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> CogVideoXExtension:
|
||||
return CogVideoXExtension()
|
||||
@ -1,144 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convert SparkVSR/CogVideoX diffusers checkpoint to ComfyUI format.
|
||||
|
||||
Usage:
|
||||
python convert_sparkvsr_to_comfy.py --model_dir path/to/sparkvsr-checkpoint \
|
||||
--output_dir ComfyUI/models/
|
||||
|
||||
This creates two files:
|
||||
- diffusion_models/cogvideox_sparkvsr.safetensors (transformer)
|
||||
- vae/cogvideox_vae.safetensors (VAE)
|
||||
|
||||
T5-XXL text encoder does not need conversion — use existing ComfyUI T5 weights.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
def remap_transformer_keys(state_dict):
|
||||
"""Remap diffusers transformer keys to ComfyUI CogVideoX naming."""
|
||||
new_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
new_key = k
|
||||
|
||||
# Patch embedding
|
||||
new_key = new_key.replace("patch_embed.proj.", "patch_embed.proj.")
|
||||
new_key = new_key.replace("patch_embed.text_proj.", "patch_embed.text_proj.")
|
||||
new_key = new_key.replace("patch_embed.pos_embedding", "patch_embed.pos_embedding")
|
||||
|
||||
# Time embedding: diffusers uses time_embedding.linear_1/2, we use time_embedding_linear_1/2
|
||||
new_key = new_key.replace("time_embedding.linear_1.", "time_embedding_linear_1.")
|
||||
new_key = new_key.replace("time_embedding.linear_2.", "time_embedding_linear_2.")
|
||||
|
||||
# OFS embedding
|
||||
new_key = new_key.replace("ofs_embedding.linear_1.", "ofs_embedding_linear_1.")
|
||||
new_key = new_key.replace("ofs_embedding.linear_2.", "ofs_embedding_linear_2.")
|
||||
|
||||
# Transformer blocks: diffusers uses transformer_blocks, we use blocks
|
||||
new_key = new_key.replace("transformer_blocks.", "blocks.")
|
||||
|
||||
# Attention: diffusers uses attn1.to_q/k/v/out, we use q/k/v/attn_out
|
||||
new_key = new_key.replace(".attn1.to_q.", ".q.")
|
||||
new_key = new_key.replace(".attn1.to_k.", ".k.")
|
||||
new_key = new_key.replace(".attn1.to_v.", ".v.")
|
||||
new_key = new_key.replace(".attn1.to_out.0.", ".attn_out.")
|
||||
new_key = new_key.replace(".attn1.norm_q.", ".norm_q.")
|
||||
new_key = new_key.replace(".attn1.norm_k.", ".norm_k.")
|
||||
|
||||
# Feed-forward: diffusers uses ff.net.0.proj/ff.net.2, we use ff_proj/ff_out
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff_proj.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff_out.")
|
||||
|
||||
# Output norms
|
||||
new_key = new_key.replace("norm_final.", "norm_final.")
|
||||
new_key = new_key.replace("norm_out.linear.", "norm_out.linear.")
|
||||
new_key = new_key.replace("norm_out.norm.", "norm_out.norm.")
|
||||
|
||||
new_sd[new_key] = v
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def remap_vae_keys(state_dict):
|
||||
"""Remap diffusers VAE keys to ComfyUI CogVideoX naming.
|
||||
|
||||
The VAE architecture is directly ported so most keys should match.
|
||||
Main differences are in block naming.
|
||||
"""
|
||||
new_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
new_key = k
|
||||
|
||||
# Encoder blocks
|
||||
new_key = new_key.replace("encoder.down_blocks.", "encoder.down_blocks.")
|
||||
new_key = new_key.replace("encoder.mid_block.", "encoder.mid_block.")
|
||||
|
||||
# Decoder blocks
|
||||
new_key = new_key.replace("decoder.up_blocks.", "decoder.up_blocks.")
|
||||
new_key = new_key.replace("decoder.mid_block.", "decoder.mid_block.")
|
||||
|
||||
# Resnet blocks within down/up/mid
|
||||
new_key = new_key.replace(".resnets.", ".resnets.")
|
||||
|
||||
# CausalConv3d: diffusers stores as .conv.weight inside CausalConv3d
|
||||
# Our CausalConv3d also has .conv.weight, so this should match
|
||||
|
||||
# Downsamplers/Upsamplers
|
||||
new_key = new_key.replace(".downsamplers.0.", ".downsamplers.0.")
|
||||
new_key = new_key.replace(".upsamplers.0.", ".upsamplers.0.")
|
||||
|
||||
new_sd[new_key] = v
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert SparkVSR/CogVideoX to ComfyUI format")
|
||||
parser.add_argument("--model_dir", type=str, required=True,
|
||||
help="Path to diffusers pipeline directory (contains transformer/, vae/, etc.)")
|
||||
parser.add_argument("--output_dir", type=str, default=".",
|
||||
help="Output base directory (will create diffusion_models/ and vae/ subdirs)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load transformer
|
||||
transformer_dir = os.path.join(args.model_dir, "transformer")
|
||||
print(f"Loading transformer from {transformer_dir}...")
|
||||
transformer_sd = {}
|
||||
for f in sorted(os.listdir(transformer_dir)):
|
||||
if f.endswith(".safetensors"):
|
||||
sd = load_file(os.path.join(transformer_dir, f))
|
||||
transformer_sd.update(sd)
|
||||
|
||||
transformer_sd = remap_transformer_keys(transformer_sd)
|
||||
|
||||
out_dir = os.path.join(args.output_dir, "diffusion_models")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
out_path = os.path.join(out_dir, "cogvideox_sparkvsr.safetensors")
|
||||
print(f"Saving transformer to {out_path} ({len(transformer_sd)} keys)")
|
||||
save_file(transformer_sd, out_path)
|
||||
|
||||
# Load VAE
|
||||
vae_dir = os.path.join(args.model_dir, "vae")
|
||||
print(f"Loading VAE from {vae_dir}...")
|
||||
vae_sd = {}
|
||||
for f in sorted(os.listdir(vae_dir)):
|
||||
if f.endswith(".safetensors"):
|
||||
sd = load_file(os.path.join(vae_dir, f))
|
||||
vae_sd.update(sd)
|
||||
|
||||
vae_sd = remap_vae_keys(vae_sd)
|
||||
|
||||
out_dir = os.path.join(args.output_dir, "vae")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
out_path = os.path.join(out_dir, "cogvideox_vae.safetensors")
|
||||
print(f"Saving VAE to {out_path} ({len(vae_sd)} keys)")
|
||||
save_file(vae_sd, out_path)
|
||||
|
||||
print("Done! T5-XXL text encoder does not need conversion.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user