ComfyUI/comfy_extras/nodes_seedvr.py
2026-07-02 22:59:38 -04:00

420 lines
18 KiB
Python

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import torch
import comfy.model_management
from comfy.ldm.seedvr.color_fix import (
adain_color_transfer,
lab_color_transfer,
wavelet_color_transfer,
)
from comfy.ldm.seedvr.constants import (
SEEDVR2_ADAIN_SCALE_MULTIPLIER,
SEEDVR2_COLOR_MEM_HEADROOM,
SEEDVR2_DTYPE_BYTES_FLOOR,
SEEDVR2_LAB_SCALE_MULTIPLIER,
SEEDVR2_LATENT_CHANNELS,
SEEDVR2_OOM_BACKOFF_DIVISOR,
SEEDVR2_WAVELET_SCALE_MULTIPLIER,
)
from torchvision.transforms import functional as TVF
from torchvision.transforms.functional import InterpolationMode
# Common prefix of every RuntimeError raised by _resolve_seedvr2_diffusion_model,
# so all "wrong model object" failures read the same way.
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
# Unique sentinel used as a getattr() default: lets the resolver tell
# "attribute is absent" apart from "attribute exists but is None".
_ATTR_MISSING = object()
def _resolve_seedvr2_diffusion_model(model):
    """Return ``model.model.diffusion_model``, validating every hop.

    Each attribute is looked up with a sentinel default so that a missing
    attribute and an attribute explicitly set to ``None`` produce distinct
    error messages.

    Args:
        model: Object expected to expose ``.model.diffusion_model``.

    Returns:
        The object stored at ``model.model.diffusion_model``.

    Raises:
        RuntimeError: If ``model`` has no ``model`` attribute, if that
            attribute is ``None``, if it has no ``diffusion_model``
            attribute, or if ``diffusion_model`` is ``None``.
    """
    inner = getattr(model, "model", _ATTR_MISSING)
    if inner is _ATTR_MISSING:
        raise RuntimeError(
            f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute "
            f"(got type {type(model).__name__})."
        )
    if inner is None:
        raise RuntimeError(
            f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None "
            f"(input type {type(model).__name__})."
        )

    diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING)
    if diffusion_model is _ATTR_MISSING:
        raise RuntimeError(
            f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no "
            f"'diffusion_model' attribute (got type {type(inner).__name__})."
        )
    if diffusion_model is None:
        raise RuntimeError(
            f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' "
            f"is None (model.model type {type(inner).__name__})."
        )
    return diffusion_model
def div_pad(image, factor):
    """Zero-pad the last two dimensions of *image* up to multiples of *factor*.

    Padding is appended on the bottom and right only; existing values keep
    their positions.

    Args:
        image: Tensor whose last two dimensions are height and width.
        factor: ``(height_factor, width_factor)`` pair of positive integers.

    Returns:
        ``image`` itself when both dimensions are already multiples of the
        corresponding factor, otherwise a new zero-padded tensor.

    Raises:
        ValueError: If either factor is not positive.
    """
    height_factor, width_factor = factor
    # A zero factor would divide by zero, and a negative one would yield a
    # negative pad amount, which torch's pad interprets as a silent crop.
    if height_factor <= 0 or width_factor <= 0:
        raise ValueError(
            f"div_pad: factor entries must be positive; got {(height_factor, width_factor)}."
        )

    height, width = image.shape[-2:]
    # For a positive divisor, -n % d is the distance from n up to the next
    # multiple of d (0 when n is already a multiple).
    pad_height = -height % height_factor
    pad_width = -width % width_factor
    if pad_height == 0 and pad_width == 0:
        return image

    # F.pad order for the last two dims: (left, right, top, bottom).
    padding = (0, pad_width, 0, pad_height)
    return torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
def cut_videos(videos):
    """Pad a video batch along dim 1 so its frame count has the form 4n + 1.

    Missing frames are filled with copies of the last frame. Frame counts
    that already satisfy 4n + 1 (including a single frame) are returned
    untouched.

    Args:
        videos: 5-D tensor whose dim 1 is the frame axis.

    Returns:
        ``videos`` itself when no padding is needed, otherwise a new tensor
        with the last frame repeated at the end.

    Raises:
        ValueError: If the frame axis is empty.
    """
    frame_count = videos.size(1)
    if frame_count < 1:
        raise ValueError("SeedVR2Preprocess expected at least one frame.")

    # Distance from (frame_count - 1) up to the next multiple of 4. This one
    # expression covers the 1-frame, 2..4-frame and longer cases alike, and
    # by construction the padded length is always 4n + 1.
    missing = -(frame_count - 1) % 4
    if missing == 0:
        return videos

    tail = videos[:, -1:].repeat(1, missing, 1, 1, 1)
    return torch.cat([videos, tail], dim=1)
def _seedvr2_input_shorter_edge(images, node_name):
    """Return the smaller of the two spatial sizes of an IMAGE tensor.

    For a 4-D tensor the spatial sizes are read from dims 1 and 2; for a
    5-D tensor from dims 2 and 3.

    Args:
        images: 4-D or 5-D tensor.
        node_name: Node name used as the prefix of the error message.

    Returns:
        ``min(height, width)`` as a plain integer.

    Raises:
        ValueError: If ``images`` is neither 4-D nor 5-D.
    """
    rank = images.dim()
    if rank == 4:
        return min(images.shape[1], images.shape[2])
    if rank == 5:
        return min(images.shape[2], images.shape[3])
    raise ValueError(
        f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
        f"got shape {tuple(images.shape)}"
    )
def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
    """Clamp, spatially pad, and frame-pad an IMAGE tensor for SeedVR2.

    Steps, in order: validate inputs, drop any channels beyond the first
    three, treat a 4-D input as a single video, clamp values to [0, 1],
    pad height and width to multiples of 16 via ``div_pad``, and pad the
    frame axis via ``cut_videos``.

    Args:
        images: 4-D ``(frames, H, W, C)`` or 5-D ``(B, T, H, W, C)`` tensor.
        upscaled_shorter_edge: Shorter spatial edge of the input; must be
            at least 2.
        node_name: Node name used as the prefix of error messages.

    Returns:
        ``io.NodeOutput`` wrapping a contiguous 5-D channel-last tensor.

    Raises:
        ValueError: If the shorter edge is below 2 or ``images`` is neither
            4-D nor 5-D.
    """
    if upscaled_shorter_edge < 2:
        raise ValueError(
            f"{node_name}: input shorter edge must be at least 2 pixels; "
            f"got {upscaled_shorter_edge}."
        )
    # Validate rank before touching the channel axis so the message shows
    # the caller's actual shape and 0-D input fails with this error too.
    if images.dim() not in (4, 5):
        raise ValueError(
            f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
            f"got shape {tuple(images.shape)}"
        )
    if images.shape[-1] > 3:
        images = images[..., :3]
    if images.dim() == 4:
        # Comfy video components arrive as a 4-D IMAGE frame sequence:
        # (frames, H, W, C). SeedVR2 consumes that as one video.
        images = images.unsqueeze(0)

    # Channel-last -> channel-first, then fold batch and time together so
    # the spatial padding sees a plain (N, C, H, W) tensor.
    images = images.permute(0, 1, 4, 2, 3)
    b, t, c, h, w = images.shape
    images = images.reshape(b * t, c, h, w)
    images = torch.clamp(images, 0.0, 1.0)
    images = div_pad(images, (16, 16))

    _, _, new_h, new_w = images.shape
    images = images.reshape(b, t, c, new_h, new_w)
    images = cut_videos(images)

    # Back to channel-last for the IMAGE output.
    images_bthwc = images.permute(0, 1, 3, 4, 2).contiguous()
    return io.NodeOutput(images_bthwc)
class SeedVR2Preprocess(io.ComfyNode):
    """Node that pads an already-resized image or frame sequence for SeedVR2."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, its single IMAGE input, and its IMAGE output."""
        return io.Schema(
            node_id="SeedVR2Preprocess",
            display_name="Pre-Process SeedVR2 Input",
            category="image/upscaling",
            description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.",
            search_aliases=["seedvr2", "upscale", "video upscale", "pad", "preprocess"],
            inputs=[
                io.Image.Input("resized_images", tooltip="The resized image to process."),
            ],
            outputs=[
                io.Image.Output("images", tooltip="The padded image for VAE encoding."),
            ]
        )

    @classmethod
    def execute(cls, resized_images):
        """Measure the input's shorter edge, then delegate to ``_seedvr2_pad``.

        Raises:
            ValueError: Propagated from the helpers when the input is not
                4-D/5-D or its shorter edge is below 2 pixels.
        """
        upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess")
        return _seedvr2_pad(
            resized_images, upscaled_shorter_edge, "SeedVR2Preprocess",
        )
class SeedVR2PostProcessing(io.ComfyNode):
    """Node that crops generated frames to the reference geometry and color-matches them."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, its two IMAGE inputs, the method combo, and the output."""
        return io.Schema(
            node_id="SeedVR2PostProcessing",
            display_name="Post-Process SeedVR2 Output",
            category="image/upscaling",
            description="Align the generated image with the original resized image and apply color correction.",
            search_aliases=["seedvr2", "upscale", "color correction", "color match", "postprocess"],
            inputs=[
                io.Image.Input("images", tooltip="The generated image to process."),
                io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."),
                io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
            ],
            outputs=[io.Image.Output(display_name="images", tooltip="The aligned, color-corrected image.")],
        )

    @classmethod
    def execute(cls, images, original_resized_images, color_correction_method):
        """Align ``images`` with the reference and optionally transfer its colors.

        Args:
            images: Generated frames, 4-D or 5-D channel-last tensor.
            original_resized_images: Reference frames, 4-D or 5-D
                channel-last tensor. A 4th channel is split off as alpha
                and re-attached to the output.
            color_correction_method: ``"lab"``, ``"wavelet"``, ``"adain"``
                or ``"none"``.

        Returns:
            ``io.NodeOutput`` with the processed frames; 4-D when ``images``
            was 4-D, otherwise 5-D. Height and width are trimmed to even
            sizes.

        Raises:
            ValueError: On an unknown method or an input that is neither
                4-D nor 5-D.
        """
        # Split alpha off the reference; it is re-attached after processing.
        alpha_input = None
        if original_resized_images.shape[-1] == 4:
            alpha_input = original_resized_images[..., 3:4]
            original_resized_images = original_resized_images[..., :3]

        decoded_5d, decoded_was_4d = cls._as_bthwc(images)
        reference_full, _ = cls._as_bthwc(original_resized_images)
        decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full)

        # Trim the generated frames to whatever both tensors have in common.
        b = min(decoded_5d.shape[0], reference_full.shape[0])
        t = min(decoded_5d.shape[1], reference_full.shape[1])
        reference_h = reference_full.shape[2]
        reference_w = reference_full.shape[3]
        decoded_5d = decoded_5d[:b, :t, :, :, :]
        target_h = min(decoded_5d.shape[2], reference_h)
        target_w = min(decoded_5d.shape[3], reference_w)
        decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]

        if color_correction_method in ("lab", "wavelet", "adain"):
            reference_5d = reference_full[:b, :t, :, :, :]
            reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
            output_device = decoded_5d.device
            decoded_raw = cls._to_seedvr2_raw(decoded_5d)
            reference_raw = cls._to_seedvr2_raw(reference_5d)
            # The transfer helpers receive (N, C, H, W) with batch and time folded.
            decoded_flat = decoded_raw.permute(0, 1, 4, 2, 3).reshape(b * t, decoded_raw.shape[4], target_h, target_w)
            reference_flat = reference_raw.permute(0, 1, 4, 2, 3).reshape(b * t, reference_raw.shape[4], target_h, target_w)
            output = cls._color_transfer_chunked(
                decoded_flat, reference_flat, output_device, color_correction_method,
            )
            output = output.reshape(b, t, output.shape[1], output.shape[2], output.shape[3]).permute(0, 1, 3, 4, 2)
            # Undo _to_seedvr2_raw: (x + 1) / 2, clamped to [0, 1].
            output = output.add(1.0).div(2.0).clamp(0.0, 1.0)
        elif color_correction_method == "none":
            output = decoded_5d
        else:
            raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")

        if alpha_input is not None:
            alpha_5d, _ = cls._as_bthwc(alpha_input)
            # Crop alpha to the output extent before appending it as a channel.
            alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :]
            output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1)

        # Drop a trailing row/column so height and width are even.
        h2 = output.shape[-3] - (output.shape[-3] % 2)
        w2 = output.shape[-2] - (output.shape[-2] % 2)
        output = output[:, :, :h2, :w2, :]
        if decoded_was_4d:
            output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1])
        return io.NodeOutput(output)

    @staticmethod
    def _as_bthwc(images):
        """Return ``(tensor_5d, was_4d)``, adding a leading dim to 4-D input.

        Raises:
            ValueError: If ``images`` is neither 4-D nor 5-D.
        """
        if images.ndim == 4:
            return images.unsqueeze(0), True
        if images.ndim == 5:
            return images, False
        raise ValueError(
            f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}"
        )

    @staticmethod
    def _restore_reference_batch_time(decoded, reference):
        """Split a batch-1 ``decoded`` tensor back into the reference's batch size.

        The reshape to ``(ref_b, frames // ref_b, H, W, C)`` happens only
        when ``decoded`` has batch size 1, its frame count divides evenly by
        the reference batch size, and each resulting batch has at least as
        many frames as the reference. Otherwise ``decoded`` is returned
        unchanged.
        """
        if decoded.shape[0] != 1:
            return decoded
        ref_b, ref_t = reference.shape[:2]
        if ref_b < 1 or decoded.shape[1] % ref_b != 0:
            return decoded
        decoded_t = decoded.shape[1] // ref_b
        if decoded_t < ref_t:
            return decoded
        return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4])

    @staticmethod
    def _to_seedvr2_raw(images):
        """Return ``images * 2 - 1`` (maps 0..1 onto -1..1)."""
        return images.mul(2.0).sub(1.0)

    @staticmethod
    def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
        """Run ``transfer_fn(decoded, reference)`` on the VAE device.

        Both inputs are moved to ``comfy.model_management.vae_device()`` and
        the result is moved to ``output_device``.
        """
        color_device = comfy.model_management.vae_device()
        decoded_flat = decoded_flat.to(device=color_device)
        reference_flat = reference_flat.to(device=color_device)
        output = transfer_fn(decoded_flat, reference_flat)
        return output.to(device=output_device)

    @staticmethod
    def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device):
        """Apply ``lab_color_transfer`` one frame at a time on the VAE device.

        Returns:
            Tensor on ``output_device`` holding one result per input frame.

        Raises:
            ValueError: If ``decoded_flat`` has no frames.
        """
        color_device = comfy.model_management.vae_device()
        frame_count = decoded_flat.shape[0]
        result = None
        for index in range(frame_count):
            # .to() returns the same tensor when no move is needed; the
            # clone presumably keeps lab_color_transfer off the caller's
            # storage - NOTE(review): confirm against lab_color_transfer.
            decoded_frame = decoded_flat[index:index + 1].to(device=color_device).clone()
            reference_frame = reference_flat[index:index + 1].to(device=color_device).clone()
            output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device)
            if result is None:
                # Allocate once, using the first result for shape and dtype.
                result = torch.empty(
                    (frame_count,) + tuple(output.shape[1:]),
                    device=output_device,
                    dtype=output.dtype,
                )
            result[index:index + 1].copy_(output)
        if result is None:
            raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.")
        return result

    @classmethod
    def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
        """Run the color transfer in chunks, shrinking the chunk on OOM.

        Starts from the estimated chunk size and retries the whole transfer
        with a smaller chunk whenever ``raise_non_oom`` lets the exception
        through (i.e. it does not re-raise it).

        Raises:
            RuntimeError: If the transfer still fails at one frame per chunk.
        """
        chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
        while True:
            try:
                return cls._run_color_transfer_chunks(
                    decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
                )
            except Exception as e:
                # Re-raises anything that is not an out-of-memory error.
                comfy.model_management.raise_non_oom(e)
                if chunk_size <= 1:
                    raise RuntimeError(
                        "SeedVR2PostProcessing: color correction OOM at one frame; "
                        f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
                    ) from e
                # Always shrink by at least one frame so the loop terminates
                # even if the imported divisor were ever set to 1.
                chunk_size = max(1, min(chunk_size - 1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR))

    @classmethod
    def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
        """Apply the selected transfer to consecutive ``chunk_size``-frame slices.

        Any method other than ``"lab"`` or ``"wavelet"`` is handled with
        ``adain_color_transfer``.

        Raises:
            ValueError: If ``decoded_flat`` has no frames.
        """
        total = decoded_flat.shape[0]
        result = None
        for start in range(0, total, chunk_size):
            end = min(start + chunk_size, total)
            decoded_chunk = decoded_flat[start:end]
            reference_chunk = reference_flat[start:end]
            if color_correction_method == "lab":
                output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device)
            elif color_correction_method == "wavelet":
                output = cls._color_transfer_on_vae_device(
                    decoded_chunk, reference_chunk, output_device, wavelet_color_transfer,
                )
            else:
                output = cls._color_transfer_on_vae_device(
                    decoded_chunk, reference_chunk, output_device, adain_color_transfer,
                )
            if result is None:
                # Allocate once, using the first chunk for shape and dtype.
                result = torch.empty(
                    (total,) + tuple(output.shape[1:]),
                    device=output_device,
                    dtype=output.dtype,
                )
            result[start:end].copy_(output)
        if result is None:
            raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.")
        return result

    @classmethod
    def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method):
        """Estimate how many frames fit in the VAE device's free memory.

        The per-frame cost is ``H * W * C * dtype_bytes * multiplier``, where
        the multiplier depends on the method. The result is always at least
        1, and never exceeds the frame count when that is positive.

        Raises:
            ValueError: If the method has no memory multiplier.
        """
        multiplier = cls._color_correction_memory_multiplier(color_correction_method)
        frames, channels, height, width = decoded_flat.shape
        dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR)
        bytes_per_frame = height * width * channels * dtype_bytes * multiplier
        if bytes_per_frame <= 0:
            # Floor at 1: a zero chunk size would make range() reject its step.
            return max(1, frames)
        color_device = comfy.model_management.vae_device()
        free_memory = comfy.model_management.get_free_memory(color_device)
        chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame)
        return max(1, min(frames, chunk_size))

    @staticmethod
    def _color_correction_memory_multiplier(color_correction_method):
        """Return the memory multiplier constant for the given method.

        Raises:
            ValueError: If the method is not ``"lab"``, ``"wavelet"`` or ``"adain"``.
        """
        if color_correction_method == "lab":
            return SEEDVR2_LAB_SCALE_MULTIPLIER
        if color_correction_method == "wavelet":
            return SEEDVR2_WAVELET_SCALE_MULTIPLIER
        if color_correction_method == "adain":
            return SEEDVR2_ADAIN_SCALE_MULTIPLIER
        raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")

    @staticmethod
    def _resize_reference(reference, height, width):
        """Bicubically resize a 5-D channel-last reference to ``height`` x ``width``.

        Returns ``reference`` itself when it already has the target size.
        Antialiasing is turned off when the tensor lives on an ``mps`` device.
        """
        if reference.shape[2] == height and reference.shape[3] == width:
            return reference
        b, t = reference.shape[:2]
        reference_flat = reference.permute(0, 1, 4, 2, 3).reshape(b * t, reference.shape[4], reference.shape[2], reference.shape[3])
        resized = TVF.resize(
            reference_flat,
            size=(height, width),
            interpolation=InterpolationMode.BICUBIC,
            antialias=reference_flat.device.type != "mps",
        )
        return resized.reshape(b, t, resized.shape[1], height, width).permute(0, 1, 3, 4, 2)
class SeedVR2Conditioning(io.ComfyNode):
    """Node that builds positive/negative conditioning from a SeedVR2 VAE latent."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, the MODEL and LATENT inputs, and both outputs."""
        return io.Schema(
            node_id="SeedVR2Conditioning",
            display_name="Apply SeedVR2 Conditioning",
            category="conditioning",
            description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
            search_aliases=["seedvr2", "upscale", "conditioning"],
            inputs=[
                io.Model.Input("model", tooltip="The SeedVR2 model."),
                io.Latent.Input("vae_conditioning", display_name="latent"),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
                io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
            ],
        )

    @classmethod
    def execute(cls, model, vae_conditioning) -> io.NodeOutput:
        """Attach the latent plus an all-ones mask channel to both conditionings.

        Args:
            model: Object exposing ``.model.diffusion_model``, which in turn
                provides ``positive_conditioning`` and
                ``negative_conditioning``.
            vae_conditioning: Mapping whose ``"samples"`` entry is a 5-D
                channel-first latent with ``SEEDVR2_LATENT_CHANNELS``
                channels in dim 1.

        Returns:
            ``io.NodeOutput(positive, negative)``; both entries carry the
            same ``"condition"`` tensor.

        Raises:
            ValueError: If the latent is not 5-D or has the wrong channel
                count (with a dedicated message for channel-last input).
            RuntimeError: If ``model`` lacks the expected structure.
        """
        vae_conditioning = vae_conditioning["samples"]
        if vae_conditioning.ndim != 5:
            raise ValueError(
                "SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
                f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
            )
        if vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
            # A matching last dim suggests the latent was passed channel-last.
            if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS:
                raise ValueError(
                    "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
                    f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
                    f"got channel-last shape {tuple(vae_conditioning.shape)}."
                )
            raise ValueError(
                "SeedVR2Conditioning expects SeedVR2 VAE latents with "
                f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(vae_conditioning.shape)}."
            )

        # Work channel-last so the mask can be appended as one extra channel.
        vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
        diffusion_model = _resolve_seedvr2_diffusion_model(model)
        pos_cond = diffusion_model.positive_conditioning
        neg_cond = diffusion_model.negative_conditioning

        mask = vae_conditioning.new_ones(vae_conditioning.shape[:-1] + (1,))
        condition = torch.cat((vae_conditioning, mask), dim=-1)
        # Present the result channel-first again.
        condition = condition.movedim(-1, 1)

        negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
        positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
        return io.NodeOutput(positive, negative)
class SeedVRExtension(ComfyExtension):
    """Extension that exposes the three SeedVR2 node classes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the conditioning, pre-process, and post-process node classes."""
        return [
            SeedVR2Conditioning,
            SeedVR2Preprocess,
            SeedVR2PostProcessing,
        ]
async def comfy_entrypoint() -> SeedVRExtension:
    """Return a new ``SeedVRExtension`` instance."""
    return SeedVRExtension()