# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# Synced 2026-07-04 05:31:03 +08:00 (420 lines, 18 KiB, Python)
from typing_extensions import override
|
|
from comfy_api.latest import ComfyExtension, io
|
|
import torch
|
|
|
|
import comfy.model_management
|
|
from comfy.ldm.seedvr.color_fix import (
|
|
adain_color_transfer,
|
|
lab_color_transfer,
|
|
wavelet_color_transfer,
|
|
)
|
|
from comfy.ldm.seedvr.constants import (
|
|
SEEDVR2_ADAIN_SCALE_MULTIPLIER,
|
|
SEEDVR2_COLOR_MEM_HEADROOM,
|
|
SEEDVR2_DTYPE_BYTES_FLOOR,
|
|
SEEDVR2_LAB_SCALE_MULTIPLIER,
|
|
SEEDVR2_LATENT_CHANNELS,
|
|
SEEDVR2_OOM_BACKOFF_DIVISOR,
|
|
SEEDVR2_WAVELET_SCALE_MULTIPLIER,
|
|
)
|
|
|
|
from torchvision.transforms import functional as TVF
|
|
from torchvision.transforms.functional import InterpolationMode
|
|
|
|
|
|
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
|
|
_ATTR_MISSING = object()
|
|
|
|
|
|
def _resolve_seedvr2_diffusion_model(model):
|
|
inner = getattr(model, "model", _ATTR_MISSING)
|
|
if inner is _ATTR_MISSING:
|
|
raise RuntimeError(
|
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute "
|
|
f"(got type {type(model).__name__})."
|
|
)
|
|
if inner is None:
|
|
raise RuntimeError(
|
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None "
|
|
f"(input type {type(model).__name__})."
|
|
)
|
|
diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING)
|
|
if diffusion_model is _ATTR_MISSING:
|
|
raise RuntimeError(
|
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no "
|
|
f"'diffusion_model' attribute (got type {type(inner).__name__})."
|
|
)
|
|
if diffusion_model is None:
|
|
raise RuntimeError(
|
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' "
|
|
f"is None (model.model type {type(inner).__name__})."
|
|
)
|
|
return diffusion_model
|
|
|
|
|
|
def div_pad(image, factor):
    """Zero-pad the last two dimensions up to multiples of ``factor``.

    Args:
        image: Tensor whose last two dimensions are treated as height and width.
        factor: ``(height_factor, width_factor)`` pair of divisors.

    Returns:
        ``image`` itself when both dimensions are already divisible, otherwise
        a new tensor padded with zeros at the bottom and right.
    """
    step_h, step_w = factor
    rows, cols = image.shape[-2:]

    # -n % f is the distance from n up to the next multiple of f (0 if aligned).
    extra_rows = -rows % step_h
    extra_cols = -cols % step_w
    if not (extra_rows or extra_cols):
        return image

    # F.pad takes pairs starting from the last dim: (left, right, top, bottom).
    return torch.nn.functional.pad(
        image, (0, extra_cols, 0, extra_rows), mode='constant', value=0.0,
    )
|
|
|
|
def cut_videos(videos):
    """Pad a video batch along dim 1 so its frame count has the form 4n + 1.

    Padding repeats the last frame. A single modular computation covers every
    length: one frame (and any other 4n + 1 length) is returned unchanged,
    and 2-4 frames are padded to 5.

    Args:
        videos: 5-D tensor whose dim 1 is the frame axis.

    Returns:
        ``videos`` itself when no padding is needed, otherwise a new tensor
        with the last frame appended as many times as required.

    Raises:
        ValueError: If ``videos`` has no frames along dim 1.
    """
    frame_count = videos.size(1)
    if frame_count < 1:
        raise ValueError("SeedVR2Preprocess expected at least one frame.")

    overshoot = (frame_count - 1) % 4
    if overshoot == 0:
        return videos

    # Appending 4 - overshoot copies makes (frames - 1) divisible by 4 by
    # construction, so no post-check is needed.
    tail = videos[:, -1:].repeat(1, 4 - overshoot, 1, 1, 1)
    return torch.cat([videos, tail], dim=1)
|
|
|
|
def _seedvr2_input_shorter_edge(images, node_name):
|
|
if images.dim() == 4:
|
|
return min(images.shape[1], images.shape[2])
|
|
if images.dim() == 5:
|
|
return min(images.shape[2], images.shape[3])
|
|
raise ValueError(
|
|
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
|
|
f"got shape {tuple(images.shape)}"
|
|
)
|
|
|
|
|
|
def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
    """Clamp, spatially pad and temporally pad an image tensor for SeedVR2.

    Steps, in order: drop channels beyond the first three, promote a 4-D
    input to 5-D, clamp values to [0, 1], zero-pad height and width to
    multiples of 16 via ``div_pad``, then extend the frame axis via
    ``cut_videos``.

    Args:
        images: 4-D or 5-D channel-last image tensor.
        upscaled_shorter_edge: Shorter spatial edge of the input; must be >= 2.
        node_name: Node label used as the prefix of error messages.

    Returns:
        ``io.NodeOutput`` wrapping a contiguous 5-D channel-last tensor.

    Raises:
        ValueError: If the shorter edge is below 2 or the tensor rank is not 4 or 5.
    """
    if upscaled_shorter_edge < 2:
        raise ValueError(
            f"{node_name}: input shorter edge must be at least 2 pixels; "
            f"got {upscaled_shorter_edge}."
        )

    # Keep only the first three channels (drops alpha when present).
    if images.shape[-1] > 3:
        images = images[..., :3]

    rank = images.dim()
    if rank == 4:
        # Comfy video components arrive as a 4-D IMAGE frame sequence:
        # (frames, H, W, C). SeedVR2 consumes that as one video.
        images = images.unsqueeze(0)
    elif rank != 5:
        raise ValueError(
            f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
            f"got shape {tuple(images.shape)}"
        )

    # Channel-last -> channel-first, then fold batch and time together so the
    # spatial padding sees a plain (N, C, H, W) stack.
    channel_first = images.permute(0, 1, 4, 2, 3)
    batch, frames, channels, height, width = channel_first.shape
    stacked = channel_first.reshape(batch * frames, channels, height, width)

    stacked = torch.clamp(stacked, 0.0, 1.0)
    stacked = div_pad(stacked, (16, 16))
    padded_h, padded_w = stacked.shape[2], stacked.shape[3]

    video = stacked.reshape(batch, frames, channels, padded_h, padded_w)
    video = cut_videos(video)

    # Back to channel-last for the node output.
    images_bthwc = video.permute(0, 1, 3, 4, 2).contiguous()
    return io.NodeOutput(images_bthwc)
|
|
|
|
|
|
class SeedVR2Preprocess(io.ComfyNode):
    """Node that pads a resized image so it can be fed to the SeedVR2 model."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, labels, one image input and one image output."""
        image_in = io.Image.Input("resized_images", tooltip="The resized image to process.")
        image_out = io.Image.Output("images", tooltip="The padded image for VAE encoding.")
        return io.Schema(
            node_id="SeedVR2Preprocess",
            display_name="Pre-Process SeedVR2 Input",
            category="image/upscaling",
            description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.",
            search_aliases=["seedvr2", "upscale", "video upscale", "pad", "preprocess"],
            inputs=[image_in],
            outputs=[image_out],
        )

    @classmethod
    def execute(cls, resized_images):
        """Measure the shorter edge of the input, then delegate to ``_seedvr2_pad``."""
        node_name = "SeedVR2Preprocess"
        shorter_edge = _seedvr2_input_shorter_edge(resized_images, node_name)
        return _seedvr2_pad(resized_images, shorter_edge, node_name)
|
|
|
|
|
|
class SeedVR2PostProcessing(io.ComfyNode):
    """Node that crops generated frames to a reference and optionally color-matches them."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, labels, three inputs and one image output."""
        method_tooltip = (
            "Method to match the generated image colors to the original image. "
            "lab: transfer color in CIELAB space, preserving detail (most faithful). "
            "wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. "
            "adain: match per-channel mean/std (fastest, global tint). "
            "none: skip color transfer (geometry alignment only)."
        )
        node_inputs = [
            io.Image.Input("images", tooltip="The generated image to process."),
            io.Image.Input(
                "original_resized_images",
                tooltip="The original resized image before pre-processing, used as reference.",
            ),
            io.Combo.Input(
                "color_correction_method",
                options=["lab", "wavelet", "adain", "none"],
                default="lab",
                tooltip=method_tooltip,
            ),
        ]
        node_outputs = [
            io.Image.Output(display_name="images", tooltip="The aligned, color-corrected image."),
        ]
        return io.Schema(
            node_id="SeedVR2PostProcessing",
            display_name="Post-Process SeedVR2 Output",
            category="image/upscaling",
            description="Align the generated image with the original resized image and apply color correction.",
            search_aliases=["seedvr2", "upscale", "color correction", "color match", "postprocess"],
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, images, original_resized_images, color_correction_method):
        """Crop ``images`` to the reference extent, color-correct, and re-attach alpha.

        Args:
            images: Generated frames, 4-D or 5-D channel-last.
            original_resized_images: Reference frames, 4-D or 5-D channel-last.
                A fourth channel, when present, is split off as alpha and
                concatenated back onto the result.
            color_correction_method: One of "lab", "wavelet", "adain", "none".

        Returns:
            ``io.NodeOutput`` wrapping the processed tensor. Height and width
            are truncated to even sizes. The result is flattened back to 4-D
            when ``images`` was 4-D.

        Raises:
            ValueError: On an unsupported tensor rank or unknown method.
        """
        alpha_input = None
        if original_resized_images.shape[-1] == 4:
            # Right-hand side is evaluated first, so both slices see the RGBA tensor.
            alpha_input, original_resized_images = (
                original_resized_images[..., 3:4],
                original_resized_images[..., :3],
            )

        decoded_5d, decoded_was_4d = cls._as_bthwc(images)
        reference_full, _ = cls._as_bthwc(original_resized_images)
        decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full)

        # Common extent of generated and reference tensors on every non-channel axis.
        b = min(decoded_5d.shape[0], reference_full.shape[0])
        t = min(decoded_5d.shape[1], reference_full.shape[1])
        target_h = min(decoded_5d.shape[2], reference_full.shape[2])
        target_w = min(decoded_5d.shape[3], reference_full.shape[3])
        decoded_5d = decoded_5d[:b, :t, :target_h, :target_w, :]

        if color_correction_method == "none":
            output = decoded_5d
        elif color_correction_method in ("lab", "wavelet", "adain"):
            reference_5d = cls._resize_reference(
                reference_full[:b, :t, :, :, :], target_h, target_w,
            )
            output_device = decoded_5d.device
            frame_total = b * t

            def raw_frames(bthwc):
                # [0, 1] -> [-1, 1], then channel-first with batch and time merged.
                raw = cls._to_seedvr2_raw(bthwc)
                return raw.permute(0, 1, 4, 2, 3).reshape(
                    frame_total, raw.shape[4], target_h, target_w,
                )

            corrected = cls._color_transfer_chunked(
                raw_frames(decoded_5d), raw_frames(reference_5d),
                output_device, color_correction_method,
            )
            out_c, out_h, out_w = corrected.shape[1:4]
            output = corrected.reshape(b, t, out_c, out_h, out_w).permute(0, 1, 3, 4, 2)
            # [-1, 1] -> [0, 1]
            output = ((output + 1.0) / 2.0).clamp(0.0, 1.0)
        else:
            raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")

        if alpha_input is not None:
            alpha_5d, _ = cls._as_bthwc(alpha_input)
            keep_b, keep_t, keep_h, keep_w = output.shape[:4]
            alpha_5d = alpha_5d[:keep_b, :keep_t, :keep_h, :keep_w, :]
            alpha_5d = alpha_5d.to(dtype=output.dtype, device=output.device)
            output = torch.cat([output, alpha_5d], dim=-1)

        # Truncate height and width down to even sizes.
        even_h = (output.shape[-3] // 2) * 2
        even_w = (output.shape[-2] // 2) * 2
        output = output[:, :, :even_h, :even_w, :]

        if decoded_was_4d:
            output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1])
        return io.NodeOutput(output)

    @staticmethod
    def _as_bthwc(images):
        """Return ``(tensor_5d, was_4d)``; a 4-D input gains a leading axis of size 1.

        Raises:
            ValueError: If ``images`` is neither 4-D nor 5-D.
        """
        rank = images.ndim
        if rank not in (4, 5):
            raise ValueError(
                f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}"
            )
        was_4d = rank == 4
        return (images.unsqueeze(0) if was_4d else images), was_4d

    @staticmethod
    def _restore_reference_batch_time(decoded, reference):
        """Split a batch-1 ``decoded`` frame axis into the reference's batch count.

        The reshape happens only when ``decoded`` has batch size 1, its frame
        count divides evenly by the reference batch size, and each resulting
        video has at least as many frames as the reference. Otherwise
        ``decoded`` is returned unchanged.
        """
        if decoded.shape[0] == 1:
            ref_batch, ref_frames = reference.shape[:2]
            if ref_batch >= 1 and decoded.shape[1] % ref_batch == 0:
                frames_each = decoded.shape[1] // ref_batch
                if frames_each >= ref_frames:
                    return decoded.reshape(
                        ref_batch, frames_each,
                        decoded.shape[2], decoded.shape[3], decoded.shape[4],
                    )
        return decoded

    @staticmethod
    def _to_seedvr2_raw(images):
        """Map values with ``x * 2 - 1`` (sends [0, 1] to [-1, 1])."""
        return images * 2.0 - 1.0

    @staticmethod
    def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
        """Run ``transfer_fn`` on the VAE device and move the result to ``output_device``."""
        work_device = comfy.model_management.vae_device()
        result = transfer_fn(
            decoded_flat.to(device=work_device),
            reference_flat.to(device=work_device),
        )
        return result.to(device=output_device)

    @staticmethod
    def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device):
        """Apply ``lab_color_transfer`` one frame at a time on the VAE device.

        Returns:
            Tensor on ``output_device`` with one entry per input frame.

        Raises:
            ValueError: If ``decoded_flat`` has no frames.
        """
        work_device = comfy.model_management.vae_device()
        frame_total = decoded_flat.shape[0]
        if frame_total == 0:
            raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.")

        result = None
        for idx in range(frame_total):
            window = slice(idx, idx + 1)
            # clone(): presumably so the transfer works on private copies even
            # when .to() returns the same tensor - TODO confirm.
            decoded_frame = decoded_flat[window].to(device=work_device).clone()
            reference_frame = reference_flat[window].to(device=work_device).clone()
            corrected = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device)
            if result is None:
                # Allocate lazily: shape and dtype are taken from the first output.
                result = torch.empty(
                    (frame_total,) + tuple(corrected.shape[1:]),
                    device=output_device,
                    dtype=corrected.dtype,
                )
            result[window].copy_(corrected)
        return result

    @classmethod
    def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
        """Run color transfer in chunks, shrinking the chunk size after each failure.

        Raises:
            RuntimeError: If a failure that got past ``raise_non_oom`` occurs
                while the chunk size is already 1.
        """
        frames_per_chunk = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
        while True:
            try:
                return cls._run_color_transfer_chunks(
                    decoded_flat, reference_flat, output_device,
                    color_correction_method, frames_per_chunk,
                )
            except Exception as e:
                # Presumably re-raises anything that is not out-of-memory;
                # see comfy.model_management.raise_non_oom.
                comfy.model_management.raise_non_oom(e)
                can_shrink = frames_per_chunk > 1
                if not can_shrink:
                    raise RuntimeError(
                        "SeedVR2PostProcessing: color correction OOM at one frame; "
                        f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
                    ) from e
                frames_per_chunk = max(1, frames_per_chunk // SEEDVR2_OOM_BACKOFF_DIVISOR)

    @classmethod
    def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
        """Color-correct ``decoded_flat`` in slices of ``chunk_size`` frames.

        "lab" uses the per-frame LAB path, "wavelet" uses
        ``wavelet_color_transfer``, and any other value uses
        ``adain_color_transfer``.

        Raises:
            ValueError: If there are no frames to process.
        """
        if color_correction_method == "lab":
            def transfer(decoded_part, reference_part):
                return cls._lab_color_transfer_on_vae_device(decoded_part, reference_part, output_device)
        else:
            kernel = wavelet_color_transfer if color_correction_method == "wavelet" else adain_color_transfer

            def transfer(decoded_part, reference_part):
                return cls._color_transfer_on_vae_device(decoded_part, reference_part, output_device, kernel)

        frame_total = decoded_flat.shape[0]
        result = None
        for start in range(0, frame_total, chunk_size):
            stop = min(start + chunk_size, frame_total)
            corrected = transfer(decoded_flat[start:stop], reference_flat[start:stop])
            if result is None:
                # Allocate lazily: shape and dtype are taken from the first chunk.
                result = torch.empty(
                    (frame_total,) + tuple(corrected.shape[1:]),
                    device=output_device,
                    dtype=corrected.dtype,
                )
            result[start:stop].copy_(corrected)

        if result is None:
            raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.")
        return result

    @classmethod
    def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method):
        """Estimate how many frames fit in the VAE device's free memory.

        The per-frame cost is element count times element size (floored at
        ``SEEDVR2_DTYPE_BYTES_FLOOR``) times a per-method multiplier. The
        budget is free memory scaled by ``SEEDVR2_COLOR_MEM_HEADROOM``.

        Returns:
            A chunk size clamped to ``[1, frames]``, or ``frames`` as-is when
            the per-frame cost is not positive.
        """
        cost_multiplier = cls._color_correction_memory_multiplier(color_correction_method)
        frames, channels, height, width = decoded_flat.shape
        element_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR)
        bytes_per_frame = height * width * channels * element_bytes * cost_multiplier
        if bytes_per_frame <= 0:
            # NOTE(review): returned unclamped, so zero frames yields 0 here,
            # and callers use this value as a range() step.
            return frames

        work_device = comfy.model_management.vae_device()
        free_memory = comfy.model_management.get_free_memory(work_device)
        budget_frames = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame)
        return max(1, min(frames, budget_frames))

    @staticmethod
    def _color_correction_memory_multiplier(color_correction_method):
        """Return the memory-scale constant for the given method.

        Raises:
            ValueError: If the method is not "lab", "wavelet" or "adain".
        """
        known_multipliers = (
            ("lab", SEEDVR2_LAB_SCALE_MULTIPLIER),
            ("wavelet", SEEDVR2_WAVELET_SCALE_MULTIPLIER),
            ("adain", SEEDVR2_ADAIN_SCALE_MULTIPLIER),
        )
        for method_name, multiplier in known_multipliers:
            if color_correction_method == method_name:
                return multiplier
        raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")

    @staticmethod
    def _resize_reference(reference, height, width):
        """Bicubic-resize a 5-D channel-last reference to ``(height, width)``.

        Returns ``reference`` unchanged when it already has the target size.
        Antialiasing is turned off when the tensor lives on an "mps" device.
        """
        source_h, source_w = reference.shape[2], reference.shape[3]
        if source_h == height and source_w == width:
            return reference

        b, t = reference.shape[:2]
        # torchvision resize expects channel-first frames.
        reference_flat = reference.permute(0, 1, 4, 2, 3).reshape(
            b * t, reference.shape[4], source_h, source_w,
        )
        on_mps = isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"
        resized = TVF.resize(
            reference_flat,
            size=(height, width),
            interpolation=InterpolationMode.BICUBIC,
            antialias=not on_mps,
        )
        return resized.reshape(b, t, resized.shape[1], height, width).permute(0, 1, 3, 4, 2)
|
|
|
|
|
|
class SeedVR2Conditioning(io.ComfyNode):
    """Node that builds SeedVR2 positive/negative conditioning from a VAE latent."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, labels, model/latent inputs and two conditioning outputs."""
        node_inputs = [
            io.Model.Input("model", tooltip="The SeedVR2 model."),
            io.Latent.Input("vae_conditioning", display_name="latent"),
        ]
        node_outputs = [
            io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
            io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
        ]
        return io.Schema(
            node_id="SeedVR2Conditioning",
            display_name="Apply SeedVR2 Conditioning",
            category="conditioning",
            description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
            search_aliases=["seedvr2", "upscale", "conditioning"],
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, vae_conditioning) -> io.NodeOutput:
        """Attach the latent plus an all-ones mask channel to the model's stored conditionings.

        Args:
            model: Object from which ``model.model.diffusion_model`` is resolved;
                that object supplies ``positive_conditioning`` and
                ``negative_conditioning``.
            vae_conditioning: Mapping whose ``"samples"`` entry is a 5-D tensor
                with ``SEEDVR2_LATENT_CHANNELS`` channels on dim 1.

        Returns:
            ``io.NodeOutput(positive, negative)``; both entries share the same
            ``condition`` tensor.

        Raises:
            ValueError: If the latent is not 5-D or has the wrong channel count.
            RuntimeError: If the model does not have the expected structure.
        """
        latent = vae_conditioning["samples"]
        if latent.ndim != 5:
            raise ValueError(
                "SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
                f"channel-first layout; got shape {tuple(latent.shape)}."
            )

        channels_match = latent.shape[1] == SEEDVR2_LATENT_CHANNELS
        if not channels_match and latent.shape[-1] == SEEDVR2_LATENT_CHANNELS:
            # The expected channel count sits on the last axis instead of dim 1.
            raise ValueError(
                "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
                f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
                f"got channel-last shape {tuple(latent.shape)}."
            )
        if not channels_match:
            raise ValueError(
                "SeedVR2Conditioning expects SeedVR2 VAE latents with "
                f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(latent.shape)}."
            )

        latent_channel_last = latent.movedim(1, -1).contiguous()

        diffusion_model = _resolve_seedvr2_diffusion_model(model)
        pos_cond = diffusion_model.positive_conditioning
        neg_cond = diffusion_model.negative_conditioning

        # Append one all-ones channel, then move channels back to dim 1.
        ones_mask = latent_channel_last.new_ones(latent_channel_last.shape[:-1] + (1,))
        condition = torch.cat((latent_channel_last, ones_mask), dim=-1).movedim(-1, 1)

        negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
        positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
        return io.NodeOutput(positive, negative)
|
|
|
|
class SeedVRExtension(ComfyExtension):
    """Extension that exposes the three SeedVR2 nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [
            SeedVR2Conditioning,
            SeedVR2Preprocess,
            SeedVR2PostProcessing,
        ]
        return node_classes
|
|
|
|
async def comfy_entrypoint() -> SeedVRExtension:
    """Create and return a new ``SeedVRExtension`` instance."""
    extension = SeedVRExtension()
    return extension
|