mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 08:19:32 +08:00
Add SeedVR2 workflow nodes
This commit is contained in:
parent
a7ea0c2773
commit
d54ce3d781
340
comfy/ldm/seedvr/color_fix.py
Normal file
340
comfy/ldm/seedvr/color_fix.py
Normal file
@ -0,0 +1,340 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.seedvr.vae import safe_interpolate_operation
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
CIELAB_DELTA,
|
||||
CIELAB_KAPPA,
|
||||
D65_WHITE_X,
|
||||
D65_WHITE_Z,
|
||||
WAVELET_DECOMP_LEVELS,
|
||||
)
|
||||
|
||||
|
||||
def wavelet_blur(image: Tensor, radius):
|
||||
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
|
||||
if radius > max_safe_radius:
|
||||
radius = max_safe_radius
|
||||
|
||||
num_channels = image.shape[1]
|
||||
|
||||
kernel_vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
|
||||
|
||||
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
|
||||
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
|
||||
|
||||
return output
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
|
||||
high_freq = torch.zeros_like(image)
|
||||
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
low_freq = wavelet_blur(image, radius)
|
||||
high_freq.add_(image).sub_(low_freq)
|
||||
image = low_freq
|
||||
|
||||
return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
# Resize style to match content spatial dimensions
|
||||
if len(content_feat.shape) >= 3:
|
||||
# safe_interpolate_operation handles FP16 conversion automatically
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Decompose both features into frequency components
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq # Free memory immediately
|
||||
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq # Free memory immediately
|
||||
|
||||
if content_high_freq.shape != style_low_freq.shape:
|
||||
style_low_freq = safe_interpolate_operation(
|
||||
style_low_freq,
|
||||
size=content_high_freq.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
content_high_freq.add_(style_low_freq)
|
||||
|
||||
return content_high_freq.clamp_(-1.0, 1.0)
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
# Interpolate reference to match source quantiles
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
del source_quantiles, ref_indices, reference_sorted
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
del matched_sorted, inverse_indices
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of CIELAB images to RGB color space."""
|
||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||
|
||||
# LAB to XYZ
|
||||
fy = (L + 16.0) / 116.0
|
||||
fx = a.div(500.0).add_(fy)
|
||||
fz = fy - b / 200.0
|
||||
del L, a, b
|
||||
|
||||
# XYZ transformation
|
||||
x = torch.where(
|
||||
fx > epsilon,
|
||||
torch.pow(fx, 3.0),
|
||||
fx.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
y = torch.where(
|
||||
fy > epsilon,
|
||||
torch.pow(fy, 3.0),
|
||||
fy.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
z = torch.where(
|
||||
fz > epsilon,
|
||||
torch.pow(fz, 3.0),
|
||||
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
del fx, fy, fz
|
||||
|
||||
# Apply D65 white point (in-place)
|
||||
x.mul_(D65_WHITE_X)
|
||||
# y *= 1.00000 # (no-op, skip)
|
||||
z.mul_(D65_WHITE_Z)
|
||||
|
||||
xyz = torch.stack([x, y, z], dim=1)
|
||||
del x, y, z
|
||||
|
||||
# Matrix multiplication: XYZ -> RGB
|
||||
B, C, H, W = xyz.shape
|
||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del xyz
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||
del xyz_flat
|
||||
|
||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del rgb_linear_flat
|
||||
|
||||
# Apply inverse gamma correction (delinearize)
|
||||
mask = rgb_linear > 0.0031308
|
||||
rgb = torch.where(
|
||||
mask,
|
||||
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
|
||||
rgb_linear * 12.92
|
||||
)
|
||||
del mask, rgb_linear
|
||||
|
||||
return torch.clamp(rgb, 0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||
# Apply sRGB gamma correction (linearize)
|
||||
mask = rgb > 0.04045
|
||||
rgb_linear = torch.where(
|
||||
mask,
|
||||
torch.pow((rgb + 0.055) / 1.055, 2.4),
|
||||
rgb / 12.92
|
||||
)
|
||||
del mask
|
||||
|
||||
# Matrix multiplication: RGB -> XYZ
|
||||
B, C, H, W = rgb_linear.shape
|
||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del rgb_linear
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||
del rgb_flat
|
||||
|
||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del xyz_flat
|
||||
|
||||
# Normalize by D65 white point (in-place)
|
||||
xyz[:, 0].div_(D65_WHITE_X) # X
|
||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||
xyz[:, 2].div_(D65_WHITE_Z) # Z
|
||||
|
||||
# XYZ to LAB transformation
|
||||
epsilon_cubed = epsilon ** 3
|
||||
mask = xyz > epsilon_cubed
|
||||
f_xyz = torch.where(
|
||||
mask,
|
||||
torch.pow(xyz, 1.0 / 3.0),
|
||||
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||
)
|
||||
del xyz, mask
|
||||
|
||||
# Extract channels and compute LAB
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||
del f_xyz
|
||||
|
||||
return torch.stack([L, a, b], dim=1)
|
||||
|
||||
def lab_color_transfer(
|
||||
content_feat: Tensor,
|
||||
style_feat: Tensor,
|
||||
luminance_weight: float = 0.8
|
||||
) -> Tensor:
|
||||
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
device = content_feat.device
|
||||
|
||||
def ensure_float32_precision(c):
|
||||
orig_dtype = c.dtype
|
||||
c = c.float()
|
||||
return c, orig_dtype
|
||||
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||
style_feat, _ = ensure_float32_precision(style_feat)
|
||||
|
||||
rgb_to_xyz_matrix = torch.tensor([
|
||||
[0.4124564, 0.3575761, 0.1804375],
|
||||
[0.2126729, 0.7151522, 0.0721750],
|
||||
[0.0193339, 0.1191920, 0.9503041]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
xyz_to_rgb_matrix = torch.tensor([
|
||||
[ 3.2404542, -1.5371385, -0.4985314],
|
||||
[-0.9692660, 1.8760108, 0.0415560],
|
||||
[ 0.0556434, -0.2040259, 1.0572252]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
epsilon = CIELAB_DELTA
|
||||
kappa = CIELAB_KAPPA
|
||||
|
||||
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
|
||||
# Convert to LAB color space
|
||||
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del content_feat
|
||||
|
||||
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del style_feat, rgb_to_xyz_matrix
|
||||
|
||||
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||
|
||||
# Handle luminance with weighted blending
|
||||
if luminance_weight < 1.0:
|
||||
# Partially match luminance for better overall color accuracy
|
||||
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
|
||||
# Blend: preserve some content L* for detail, adopt some style L* for color
|
||||
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
|
||||
del matched_L
|
||||
else:
|
||||
# Fully preserve content luminance
|
||||
result_L = content_lab[:, 0]
|
||||
|
||||
del content_lab, style_lab
|
||||
|
||||
# Reconstruct LAB with corrected channels
|
||||
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||
del result_L, matched_a, matched_b
|
||||
|
||||
# Convert back to RGB
|
||||
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||
del result_lab, xyz_to_rgb_matrix
|
||||
|
||||
# Convert back to [-1, 1] range (in-place)
|
||||
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||
del result_rgb
|
||||
|
||||
result = result.to(original_dtype)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
return wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
|
||||
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
original_dtype = content_feat.dtype
|
||||
content_feat = content_feat.float()
|
||||
style_feat = style_feat.float()
|
||||
|
||||
b, c = content_feat.shape[:2]
|
||||
content_flat = content_feat.reshape(b, c, -1)
|
||||
style_flat = style_feat.reshape(b, c, -1)
|
||||
|
||||
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||
del content_flat, style_flat
|
||||
|
||||
normalized = (content_feat - content_mean) / content_std
|
||||
del content_mean, content_std
|
||||
result = normalized * style_std + style_mean
|
||||
del normalized, style_mean, style_std
|
||||
|
||||
result = result.clamp_(-1.0, 1.0)
|
||||
if result.dtype != original_dtype:
|
||||
result = result.to(original_dtype)
|
||||
return result
|
||||
997
comfy_extras/nodes_seedvr.py
Normal file
997
comfy_extras/nodes_seedvr.py
Normal file
@ -0,0 +1,997 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import torch
|
||||
import math
|
||||
import logging
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.sample
|
||||
import comfy.samplers
|
||||
from comfy.ldm.seedvr.color_fix import (
|
||||
adain_color_transfer,
|
||||
lab_color_transfer,
|
||||
wavelet_color_transfer,
|
||||
)
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER,
|
||||
SEEDVR2_CHUNK_FRAMES_PER_GB,
|
||||
SEEDVR2_CHUNK_GB_MARGIN,
|
||||
SEEDVR2_COLOR_MEM_HEADROOM,
|
||||
SEEDVR2_COND_CHANNELS,
|
||||
SEEDVR2_DTYPE_BYTES_FLOOR,
|
||||
SEEDVR2_LAB_SCALE_MULTIPLIER,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
SEEDVR2_OOM_BACKOFF_DIVISOR,
|
||||
SEEDVR2_WAVELET_SCALE_MULTIPLIER,
|
||||
)
|
||||
|
||||
from torchvision.transforms import functional as TVF
|
||||
from torchvision.transforms import Lambda
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
|
||||
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = (
|
||||
"SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
|
||||
)
|
||||
|
||||
# Private sentinel for getattr default: distinguishes "attribute missing"
|
||||
# from "attribute present but None" so the failure message is accurate.
|
||||
_ATTR_MISSING = object()
|
||||
|
||||
|
||||
def _seedvr2_vram_seed_frames_per_chunk(free_bytes, t_pixel):
|
||||
"""Predict the largest 4n+1 pixel-frame chunk that fits in free_bytes."""
|
||||
free_gb = free_bytes / (1024 ** 3)
|
||||
predicted = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_gb - SEEDVR2_CHUNK_GB_MARGIN)
|
||||
# round (not floor) to 4n+1: the fit's central prediction lands on measured n_max
|
||||
n = round((predicted - 1) / 4)
|
||||
seed = 4 * int(n) + 1
|
||||
seed = max(1, min(seed, t_pixel))
|
||||
return seed
|
||||
|
||||
|
||||
def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk):
|
||||
"""Return stricter 4n+1 frame chunk sizes for auto OOM retries."""
|
||||
attempts = [frames_per_chunk]
|
||||
current_chunk_latent = (
|
||||
t_latent if t_pixel <= frames_per_chunk
|
||||
else (frames_per_chunk - 1) // 4 + 1
|
||||
)
|
||||
current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent))
|
||||
seen = {frames_per_chunk}
|
||||
|
||||
for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1):
|
||||
chunk_latent = max(1, math.ceil(t_latent / target_chunks))
|
||||
candidate = 4 * (chunk_latent - 1) + 1
|
||||
if candidate in seen:
|
||||
continue
|
||||
if candidate >= attempts[-1]:
|
||||
continue
|
||||
attempts.append(candidate)
|
||||
seen.add(candidate)
|
||||
|
||||
return attempts
|
||||
|
||||
|
||||
def _resolve_seedvr2_diffusion_model(model):
|
||||
"""Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message."""
|
||||
inner = getattr(model, "model", _ATTR_MISSING)
|
||||
if inner is _ATTR_MISSING:
|
||||
raise RuntimeError(
|
||||
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute "
|
||||
f"(got type {type(model).__name__})."
|
||||
)
|
||||
if inner is None:
|
||||
raise RuntimeError(
|
||||
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None "
|
||||
f"(input type {type(model).__name__})."
|
||||
)
|
||||
diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING)
|
||||
if diffusion_model is _ATTR_MISSING:
|
||||
raise RuntimeError(
|
||||
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no "
|
||||
f"'diffusion_model' attribute (got type {type(inner).__name__})."
|
||||
)
|
||||
if diffusion_model is None:
|
||||
raise RuntimeError(
|
||||
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' "
|
||||
f"is None (model.model type {type(inner).__name__})."
|
||||
)
|
||||
return diffusion_model
|
||||
|
||||
|
||||
def _apply_rope_freqs_float32_cast(diffusion_model):
|
||||
"""Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype."""
|
||||
for module in diffusion_model.modules():
|
||||
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
|
||||
if module.rope.freqs.data.dtype != torch.float32:
|
||||
module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)
|
||||
|
||||
|
||||
def get_conditions(latent, latent_blur):
|
||||
t, h, w, c = latent.shape
|
||||
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
|
||||
cond[:, ..., :-1] = latent_blur[:]
|
||||
cond[:, ..., -1:] = 1.0
|
||||
return cond
|
||||
|
||||
def div_pad(image, factor):
|
||||
|
||||
height_factor, width_factor = factor
|
||||
height, width = image.shape[-2:]
|
||||
|
||||
pad_height = (height_factor - (height % height_factor)) % height_factor
|
||||
pad_width = (width_factor - (width % width_factor)) % width_factor
|
||||
|
||||
if pad_height == 0 and pad_width == 0:
|
||||
return image
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
padding = (0, pad_width, 0, pad_height)
|
||||
image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
|
||||
|
||||
return image
|
||||
|
||||
def cut_videos(videos):
|
||||
t = videos.size(1)
|
||||
if t == 1:
|
||||
return videos
|
||||
if t <= 4 :
|
||||
padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1)
|
||||
padding = torch.cat(padding, dim=1)
|
||||
videos = torch.cat([videos, padding], dim=1)
|
||||
return videos
|
||||
if (t - 1) % (4) == 0:
|
||||
return videos
|
||||
else:
|
||||
padding = [videos[:, -1].unsqueeze(1)] * (
|
||||
4 - ((t - 1) % (4))
|
||||
)
|
||||
padding = torch.cat(padding, dim=1)
|
||||
videos = torch.cat([videos, padding], dim=1)
|
||||
assert (videos.size(1) - 1) % (4) == 0
|
||||
return videos
|
||||
|
||||
def _seedvr2_input_shorter_edge(images, node_name):
|
||||
if images.dim() == 4:
|
||||
return min(images.shape[1], images.shape[2])
|
||||
if images.dim() == 5:
|
||||
return min(images.shape[2], images.shape[3])
|
||||
raise ValueError(
|
||||
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
|
||||
f"got shape {tuple(images.shape)}"
|
||||
)
|
||||
|
||||
|
||||
def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
|
||||
if upscaled_shorter_edge < 2:
|
||||
raise ValueError(
|
||||
f"{node_name}: input shorter edge must be at least 2 pixels; "
|
||||
f"got {upscaled_shorter_edge}."
|
||||
)
|
||||
if images.shape[-1] > 3:
|
||||
images = images[..., :3]
|
||||
if images.dim() == 4:
|
||||
# Comfy video components arrive as a 4-D IMAGE frame sequence:
|
||||
# (frames, H, W, C). SeedVR2 consumes that as one video.
|
||||
images = images.unsqueeze(0)
|
||||
elif images.dim() != 5:
|
||||
raise ValueError(
|
||||
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
|
||||
f"got shape {tuple(images.shape)}"
|
||||
)
|
||||
images = images.permute(0, 1, 4, 2, 3)
|
||||
|
||||
b, t, c, h, w = images.shape
|
||||
images = images.reshape(b * t, c, h, w)
|
||||
|
||||
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
||||
images = clip(images)
|
||||
images = div_pad(images, (16, 16))
|
||||
_, _, new_h, new_w = images.shape
|
||||
|
||||
images = images.reshape(b, t, c, new_h, new_w)
|
||||
images = cut_videos(images)
|
||||
images_bthwc = rearrange(images, "b t c h w -> b t h w c")
|
||||
|
||||
return io.NodeOutput(images_bthwc)
|
||||
|
||||
|
||||
class SeedVR2Preprocess(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2Preprocess",
|
||||
display_name="Pre-Process SeedVR2 Input",
|
||||
category="image/upscaling",
|
||||
description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.",
|
||||
search_aliases=["seedvr2", "upscale", "video upscale", "pad", "preprocess"],
|
||||
inputs=[
|
||||
io.Image.Input("resized_images", tooltip="The resized image to process."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("images", tooltip="The padded image for VAE encoding."),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, resized_images):
|
||||
upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess")
|
||||
return _seedvr2_pad(
|
||||
resized_images, upscaled_shorter_edge, "SeedVR2Preprocess",
|
||||
)
|
||||
|
||||
|
||||
class SeedVR2PostProcessing(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2PostProcessing",
|
||||
display_name="Post-Process SeedVR2 Output",
|
||||
category="image/upscaling",
|
||||
description="Align the generated image with the original resized image and apply color correction.",
|
||||
search_aliases=["seedvr2", "upscale", "color correction", "color match", "postprocess"],
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="The generated image to process."),
|
||||
io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."),
|
||||
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
|
||||
],
|
||||
outputs=[io.Image.Output(display_name="images", tooltip="The aligned, color-corrected image.")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, original_resized_images, color_correction_method):
|
||||
alpha_input = None
|
||||
if original_resized_images.shape[-1] == 4:
|
||||
alpha_input = original_resized_images[..., 3:4]
|
||||
original_resized_images = original_resized_images[..., :3]
|
||||
decoded_5d, decoded_was_4d = cls._as_bthwc(images)
|
||||
reference_full, _ = cls._as_bthwc(original_resized_images)
|
||||
decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full)
|
||||
|
||||
b = min(decoded_5d.shape[0], reference_full.shape[0])
|
||||
t = min(decoded_5d.shape[1], reference_full.shape[1])
|
||||
reference_h = reference_full.shape[2]
|
||||
reference_w = reference_full.shape[3]
|
||||
|
||||
decoded_5d = decoded_5d[:b, :t, :, :, :]
|
||||
target_h = min(decoded_5d.shape[2], reference_h)
|
||||
target_w = min(decoded_5d.shape[3], reference_w)
|
||||
decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]
|
||||
if color_correction_method in ("lab", "wavelet", "adain"):
|
||||
reference_5d = reference_full[:b, :t, :, :, :]
|
||||
reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
|
||||
output_device = decoded_5d.device
|
||||
decoded_raw = cls._to_seedvr2_raw(decoded_5d)
|
||||
reference_raw = cls._to_seedvr2_raw(reference_5d)
|
||||
decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w")
|
||||
reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w")
|
||||
output = cls._color_transfer_chunked(
|
||||
decoded_flat, reference_flat, output_device, color_correction_method,
|
||||
)
|
||||
output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t)
|
||||
output = output.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||
elif color_correction_method == "none":
|
||||
output = decoded_5d
|
||||
else:
|
||||
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
||||
|
||||
if alpha_input is not None:
|
||||
alpha_5d, _ = cls._as_bthwc(alpha_input)
|
||||
alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :]
|
||||
output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1)
|
||||
h2 = output.shape[-3] - (output.shape[-3] % 2)
|
||||
w2 = output.shape[-2] - (output.shape[-2] % 2)
|
||||
output = output[:, :, :h2, :w2, :]
|
||||
if decoded_was_4d:
|
||||
output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1])
|
||||
return io.NodeOutput(output)
|
||||
|
||||
@staticmethod
|
||||
def _as_bthwc(images):
|
||||
if images.ndim == 4:
|
||||
return images.unsqueeze(0), True
|
||||
if images.ndim == 5:
|
||||
return images, False
|
||||
raise ValueError(
|
||||
f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _restore_reference_batch_time(decoded, reference):
|
||||
if decoded.shape[0] != 1:
|
||||
return decoded
|
||||
ref_b, ref_t = reference.shape[:2]
|
||||
if ref_b < 1 or decoded.shape[1] % ref_b != 0:
|
||||
return decoded
|
||||
decoded_t = decoded.shape[1] // ref_b
|
||||
if decoded_t < ref_t:
|
||||
return decoded
|
||||
return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4])
|
||||
|
||||
@staticmethod
|
||||
def _to_seedvr2_raw(images):
|
||||
return images.mul(2.0).sub(1.0)
|
||||
|
||||
@staticmethod
|
||||
def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
|
||||
color_device = comfy.model_management.vae_device()
|
||||
decoded_flat = decoded_flat.to(device=color_device)
|
||||
reference_flat = reference_flat.to(device=color_device)
|
||||
output = transfer_fn(decoded_flat, reference_flat)
|
||||
return output.to(device=output_device)
|
||||
|
||||
@staticmethod
|
||||
def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device):
|
||||
color_device = comfy.model_management.vae_device()
|
||||
result = None
|
||||
for start in range(decoded_flat.shape[0]):
|
||||
decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone()
|
||||
reference_frame = reference_flat[start:start + 1].to(device=color_device).clone()
|
||||
output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device)
|
||||
if result is None:
|
||||
result = torch.empty(
|
||||
(decoded_flat.shape[0],) + tuple(output.shape[1:]),
|
||||
device=output_device,
|
||||
dtype=output.dtype,
|
||||
)
|
||||
result[start:start + 1].copy_(output)
|
||||
if result is None:
|
||||
raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.")
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
|
||||
chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
|
||||
while True:
|
||||
next_chunk_size = None
|
||||
try:
|
||||
return cls._run_color_transfer_chunks(
|
||||
decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
|
||||
)
|
||||
except Exception as e:
|
||||
comfy.model_management.raise_non_oom(e)
|
||||
if chunk_size <= 1:
|
||||
raise RuntimeError(
|
||||
"SeedVR2PostProcessing: color correction OOM at one frame; "
|
||||
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
|
||||
) from e
|
||||
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
|
||||
|
||||
comfy.model_management.soft_empty_cache()
|
||||
chunk_size = next_chunk_size
|
||||
|
||||
@classmethod
|
||||
def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
|
||||
result = None
|
||||
for start in range(0, decoded_flat.shape[0], chunk_size):
|
||||
end = min(start + chunk_size, decoded_flat.shape[0])
|
||||
decoded_chunk = decoded_flat[start:end]
|
||||
reference_chunk = reference_flat[start:end]
|
||||
if color_correction_method == "lab":
|
||||
output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device)
|
||||
elif color_correction_method == "wavelet":
|
||||
output = cls._color_transfer_on_vae_device(
|
||||
decoded_chunk, reference_chunk, output_device, wavelet_color_transfer,
|
||||
)
|
||||
else:
|
||||
output = cls._color_transfer_on_vae_device(
|
||||
decoded_chunk, reference_chunk, output_device, adain_color_transfer,
|
||||
)
|
||||
if result is None:
|
||||
result = torch.empty(
|
||||
(decoded_flat.shape[0],) + tuple(output.shape[1:]),
|
||||
device=output_device,
|
||||
dtype=output.dtype,
|
||||
)
|
||||
result[start:end].copy_(output)
|
||||
if result is None:
|
||||
raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.")
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method):
|
||||
multiplier = cls._color_correction_memory_multiplier(color_correction_method)
|
||||
frames = decoded_flat.shape[0]
|
||||
_, channels, height, width = decoded_flat.shape
|
||||
dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR)
|
||||
bytes_per_frame = height * width * channels * dtype_bytes * multiplier
|
||||
if bytes_per_frame <= 0:
|
||||
return frames
|
||||
color_device = comfy.model_management.vae_device()
|
||||
free_memory = comfy.model_management.get_free_memory(color_device)
|
||||
chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame)
|
||||
return max(1, min(frames, chunk_size))
|
||||
|
||||
@staticmethod
|
||||
def _color_correction_memory_multiplier(color_correction_method):
|
||||
if color_correction_method == "lab":
|
||||
return SEEDVR2_LAB_SCALE_MULTIPLIER
|
||||
if color_correction_method == "wavelet":
|
||||
return SEEDVR2_WAVELET_SCALE_MULTIPLIER
|
||||
if color_correction_method == "adain":
|
||||
return SEEDVR2_ADAIN_SCALE_MULTIPLIER
|
||||
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
||||
|
||||
@staticmethod
|
||||
def _resize_reference(reference, height, width):
|
||||
if reference.shape[2] == height and reference.shape[3] == width:
|
||||
return reference
|
||||
b, t = reference.shape[:2]
|
||||
reference_flat = rearrange(reference, "b t h w c -> (b t) c h w")
|
||||
resized = TVF.resize(
|
||||
reference_flat,
|
||||
size=(height, width),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"),
|
||||
)
|
||||
return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t)
|
||||
|
||||
|
||||
class SeedVR2Conditioning(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2Conditioning",
|
||||
display_name="Apply SeedVR2 Conditioning",
|
||||
category="conditioning",
|
||||
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
||||
search_aliases=["seedvr2", "upscale", "conditioning"],
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
||||
io.Latent.Input("vae_conditioning", display_name="latent"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."),
|
||||
io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
|
||||
io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
|
||||
io.Latent.Output(display_name="latent", tooltip="The latent to denoise."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, vae_conditioning) -> io.NodeOutput:
|
||||
|
||||
vae_conditioning = vae_conditioning["samples"]
|
||||
if vae_conditioning.ndim != 5:
|
||||
raise ValueError(
|
||||
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
|
||||
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
|
||||
)
|
||||
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
|
||||
raise ValueError(
|
||||
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
|
||||
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
|
||||
f"got channel-last shape {tuple(vae_conditioning.shape)}."
|
||||
)
|
||||
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
|
||||
model_patcher = model
|
||||
model = _resolve_seedvr2_diffusion_model(model_patcher)
|
||||
pos_cond = model.positive_conditioning
|
||||
neg_cond = model.negative_conditioning
|
||||
|
||||
# Fail-loud guard against silently-wrong output when a
|
||||
# DiT-only ``.safetensors`` (no ``positive_conditioning`` /
|
||||
# ``negative_conditioning`` keys) is loaded via ``UNETLoader``.
|
||||
# ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see
|
||||
# ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)``
|
||||
# leaves them at zero when the keys are absent. Detect that state
|
||||
# here rather than at ``BaseModel.extra_conds`` (per sampling step,
|
||||
# wasteful) or at the resolver helper (mixes structural shape with
|
||||
# semantic content). Both buffers must be checked together — partial
|
||||
# bake regressions could populate one but not the other.
|
||||
if (
|
||||
pos_cond.float().abs().sum().item() == 0
|
||||
and neg_cond.float().abs().sum().item() == 0
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
|
||||
f"and negative_conditioning buffers are zero-valued — model "
|
||||
f"file appears to be a DiT-only export missing "
|
||||
f"the SeedVR2 conditioning tensors. "
|
||||
f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
|
||||
f"and ``negative_conditioning`` (64, 5120) keys at top level, "
|
||||
f"or load via CheckpointLoaderSimple from a bundled "
|
||||
f"checkpoint."
|
||||
)
|
||||
|
||||
_apply_rope_freqs_float32_cast(model)
|
||||
|
||||
condition = torch.stack([get_conditions(c, c) for c in vae_conditioning])
|
||||
condition = condition.movedim(-1, 1)
|
||||
latent = vae_conditioning.movedim(-1, 1)
|
||||
|
||||
latent = rearrange(latent, "b c t h w -> b (c t) h w")
|
||||
condition = rearrange(condition, "b c t h w -> b (c t) h w")
|
||||
|
||||
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
|
||||
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
|
||||
|
||||
return io.NodeOutput(model_patcher, positive, negative, {"samples": latent})
|
||||
|
||||
def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int,
|
||||
t_end: int, channels: int) -> torch.Tensor:
|
||||
"""Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse."""
|
||||
B, CT, H, W = tensor_4d.shape
|
||||
if CT % channels != 0:
|
||||
raise ValueError(
|
||||
f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not "
|
||||
f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}."
|
||||
)
|
||||
T = CT // channels
|
||||
if not (0 <= t_start < t_end <= T):
|
||||
raise ValueError(
|
||||
f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of "
|
||||
f"range for T={T}."
|
||||
)
|
||||
new_T = t_end - t_start
|
||||
sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous()
|
||||
return sliced.reshape(B, channels * new_T, H, W)
|
||||
|
||||
|
||||
def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int):
|
||||
"""Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated."""
|
||||
new_list = []
|
||||
for entry in cond_list:
|
||||
text_cond, options = entry[0], entry[1]
|
||||
if "condition" not in options:
|
||||
new_list.append(entry)
|
||||
continue
|
||||
new_options = options.copy()
|
||||
new_options["condition"] = _slice_collapsed_4d_along_t(
|
||||
new_options["condition"], t_start, t_end,
|
||||
SEEDVR2_COND_CHANNELS,
|
||||
)
|
||||
new_list.append([text_cond, new_options])
|
||||
return new_list
|
||||
|
||||
|
||||
def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor,
|
||||
samples_4d: torch.Tensor,
|
||||
t_start: int,
|
||||
t_end: int):
|
||||
"""Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand."""
|
||||
if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]:
|
||||
return _slice_collapsed_4d_along_t(
|
||||
noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS,
|
||||
)
|
||||
return noise_mask
|
||||
|
||||
|
||||
def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
|
||||
"""Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D."""
|
||||
if len(chunks_4d) == 0:
|
||||
raise ValueError("_concat_chunks_along_t: empty chunk list.")
|
||||
fives = []
|
||||
for ch in chunks_4d:
|
||||
B, CT, H, W = ch.shape
|
||||
if CT % channels != 0:
|
||||
raise ValueError(
|
||||
f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} "
|
||||
f"channel dim {CT} not divisible by channels={channels}."
|
||||
)
|
||||
T = CT // channels
|
||||
fives.append(ch.reshape(B, channels, T, H, W))
|
||||
cat = torch.cat(fives, dim=2).contiguous()
|
||||
B, C, T_total, H, W = cat.shape
|
||||
return cat.reshape(B, C * T_total, H, W)
|
||||
|
||||
|
||||
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
|
||||
"""1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``):
|
||||
Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3``
|
||||
(dead-band would collapse a tiny transition). Window shape matched to the reference
|
||||
overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``.
|
||||
"""
|
||||
if overlap < 1:
|
||||
raise ValueError(
|
||||
f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}."
|
||||
)
|
||||
if overlap >= 3:
|
||||
t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype)
|
||||
blend_start = 1.0 / 3.0
|
||||
blend_end = 2.0 / 3.0
|
||||
u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0)
|
||||
return 0.5 + 0.5 * torch.cos(torch.pi * u)
|
||||
return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype)
|
||||
|
||||
|
||||
def _blend_overlap_region(prev_tail_5d: torch.Tensor,
|
||||
cur_head_5d: torch.Tensor) -> torch.Tensor:
|
||||
"""Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device)."""
|
||||
if prev_tail_5d.shape != cur_head_5d.shape:
|
||||
raise ValueError(
|
||||
f"_blend_overlap_region: shape mismatch "
|
||||
f"prev {tuple(prev_tail_5d.shape)} vs "
|
||||
f"cur {tuple(cur_head_5d.shape)}."
|
||||
)
|
||||
overlap = int(prev_tail_5d.shape[2])
|
||||
w_prev_1d = _hann_blend_weights_1d(
|
||||
overlap, prev_tail_5d.device, prev_tail_5d.dtype,
|
||||
)
|
||||
# Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W.
|
||||
w_prev = w_prev_1d.view(1, 1, overlap, 1, 1)
|
||||
w_cur = 1.0 - w_prev
|
||||
return prev_tail_5d * w_prev + cur_head_5d * w_cur
|
||||
|
||||
|
||||
def _concat_chunks_with_overlap_blend(chunk_specs, channels: int,
|
||||
overlap_latent: int) -> torch.Tensor:
|
||||
"""Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk."""
|
||||
if len(chunk_specs) == 0:
|
||||
raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.")
|
||||
if overlap_latent < 0:
|
||||
raise ValueError(
|
||||
f"_concat_chunks_with_overlap_blend: overlap_latent must be "
|
||||
f">= 0; got {overlap_latent}."
|
||||
)
|
||||
|
||||
# Validate channel divisibility once and capture per-chunk T.
|
||||
chunk_5d = []
|
||||
for t_start, t_end, ch in chunk_specs:
|
||||
B, CT, H, W = ch.shape
|
||||
if CT % channels != 0:
|
||||
raise ValueError(
|
||||
f"_concat_chunks_with_overlap_blend: chunk shape "
|
||||
f"{tuple(ch.shape)} channel dim {CT} not divisible "
|
||||
f"by channels={channels}."
|
||||
)
|
||||
T = CT // channels
|
||||
if t_end - t_start != T:
|
||||
raise ValueError(
|
||||
f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches "
|
||||
f"declared range [{t_start}:{t_end}]."
|
||||
)
|
||||
chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W)))
|
||||
|
||||
if overlap_latent == 0:
|
||||
# Fast path: pure concat in the caller-provided chunk order.
|
||||
return _concat_chunks_along_t(
|
||||
[c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4])
|
||||
for _, _, c in chunk_5d],
|
||||
channels,
|
||||
)
|
||||
|
||||
T_total = max(t_end for _, t_end, _ in chunk_5d)
|
||||
first_5d = chunk_5d[0][2]
|
||||
B = first_5d.shape[0]
|
||||
H = first_5d.shape[3]
|
||||
W = first_5d.shape[4]
|
||||
result = torch.empty(
|
||||
(B, channels, T_total, H, W),
|
||||
device=first_5d.device, dtype=first_5d.dtype,
|
||||
)
|
||||
filled_until = 0
|
||||
for i, (cs, ce, ct_5d) in enumerate(chunk_5d):
|
||||
chunk_T = int(ct_5d.shape[2])
|
||||
if i == 0:
|
||||
result[:, :, cs:ce, :, :] = ct_5d
|
||||
filled_until = ce
|
||||
continue
|
||||
# Overlap region width is bounded by both the previous fill
|
||||
# frontier and the current chunk's actual length (for runt
|
||||
# final chunks shorter than the configured overlap).
|
||||
overlap_len = min(filled_until - cs, chunk_T)
|
||||
if overlap_len > 0:
|
||||
prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous()
|
||||
cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous()
|
||||
blended = _blend_overlap_region(prev_tail, cur_head)
|
||||
result[:, :, cs:cs + overlap_len, :, :] = blended
|
||||
tail_start = cs + overlap_len
|
||||
tail_end = ce
|
||||
if tail_end > tail_start:
|
||||
result[:, :, tail_start:tail_end, :, :] = (
|
||||
ct_5d[:, :, overlap_len:, :, :]
|
||||
)
|
||||
else:
|
||||
# Disjoint chunks (overlap_latent set but this pair did not
|
||||
# actually overlap, e.g. step_latent equal to chunk_latent
|
||||
# in a degenerate config). Treat as concat.
|
||||
result[:, :, cs:ce, :, :] = ct_5d
|
||||
filled_until = ce
|
||||
|
||||
return result.contiguous().reshape(B, channels * T_total, H, W)
|
||||
|
||||
|
||||
def _run_standard_sample(model, seed: int, steps: int, cfg: float,
|
||||
sampler_name: str, scheduler: str,
|
||||
positive, negative, latent: dict,
|
||||
denoise: float) -> dict:
|
||||
"""Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk."""
|
||||
samples_in = latent["samples"]
|
||||
samples_in = comfy.sample.fix_empty_latent_channels(
|
||||
model, samples_in, latent.get("downscale_ratio_spacial", None),
|
||||
)
|
||||
batch_inds = latent.get("batch_index", None)
|
||||
noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds)
|
||||
noise_mask = latent.get("noise_mask", None)
|
||||
samples = comfy.sample.sample(
|
||||
model, noise, steps, cfg, sampler_name, scheduler,
|
||||
positive, negative, samples_in,
|
||||
denoise=denoise, noise_mask=noise_mask, seed=seed,
|
||||
)
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = samples
|
||||
return out
|
||||
|
||||
|
||||
class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||
"""Sequential temporal chunking sampler for SeedVR2 native.
|
||||
|
||||
Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that
|
||||
OOM on long sequences. The latent enters the sampler in SeedVR2's
|
||||
collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning``
|
||||
at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that
|
||||
tensor along the temporal axis, runs the configured inner sampler
|
||||
sequentially per chunk against the standard ``comfy.sample.sample``
|
||||
entry point, and concatenates per-chunk outputs back into a single
|
||||
``(B, 16*T_total, H, W)`` latent.
|
||||
|
||||
``frames_per_chunk`` is expressed in pixel-frame units to match the
|
||||
SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the
|
||||
VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F``
|
||||
maps to ``(F - 1) // 4 + 1`` latent-frame chunks.
|
||||
|
||||
Determinism contract: a single noise tensor is generated once from
|
||||
the user seed and sliced per chunk (rather than re-seeding each
|
||||
chunk), so a workflow that fits in a single chunk produces output
|
||||
identical to a workflow that fits in N chunks at the same seed,
|
||||
modulo the inherent T-axis chunk-boundary independence of the model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2ProgressiveSampler",
|
||||
display_name="Sample SeedVR2 (Progressive)",
|
||||
category="sampling",
|
||||
description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.",
|
||||
search_aliases=["seedvr2", "upscale", "video upscale", "sampler", "chunk"],
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model used for denoising the input latent."),
|
||||
io.Int.Input("seed", default=0, min=0,
|
||||
max=0xffffffffffffffff,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise."),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000,
|
||||
tooltip="The number of steps used in the denoising process."),
|
||||
io.Float.Input("cfg", default=1.0, min=0.0, max=100.0,
|
||||
step=0.1, round=0.01,
|
||||
tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."),
|
||||
io.Combo.Input("sampler_name",
|
||||
options=comfy.samplers.SAMPLER_NAMES,
|
||||
tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."),
|
||||
io.Combo.Input("scheduler",
|
||||
options=comfy.samplers.SCHEDULER_NAMES,
|
||||
tooltip="The scheduler controls how noise is gradually removed to form the image."),
|
||||
io.Conditioning.Input("positive",
|
||||
tooltip="The conditioning describing the attributes you want to include in the image."),
|
||||
io.Conditioning.Input("negative",
|
||||
tooltip="The conditioning describing the attributes you want to exclude from the image."),
|
||||
io.Latent.Input("latent",
|
||||
tooltip="The latent image to denoise."),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0,
|
||||
step=0.01,
|
||||
tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."),
|
||||
io.Int.Input("frames_per_chunk", default=21, min=1,
|
||||
max=16384, step=4,
|
||||
tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."),
|
||||
io.Int.Input("temporal_overlap", default=0, min=0,
|
||||
max=16384,
|
||||
tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."),
|
||||
io.Combo.Input("chunking_mode",
|
||||
options=["manual", "auto"],
|
||||
default="manual",
|
||||
tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."),
|
||||
],
|
||||
outputs=[io.Latent.Output(display_name="latent", tooltip="The upscaled latent.")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, seed, steps, cfg, sampler_name, scheduler,
|
||||
positive, negative, latent, denoise,
|
||||
frames_per_chunk, temporal_overlap,
|
||||
chunking_mode="manual") -> io.NodeOutput:
|
||||
# 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline
|
||||
# requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...),
|
||||
# imposed at ``cut_videos`` upstream and propagated through the VAE's
|
||||
# temporal_downsample_factor=4. Reject violations explicitly before
|
||||
# any model invocation; a silent rounding would mis-align chunk
|
||||
# boundaries with the 4n+1 lattice.
|
||||
if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0:
|
||||
raise ValueError(
|
||||
f"SeedVR2ProgressiveSampler: frames_per_chunk must be a "
|
||||
f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); "
|
||||
f"got {frames_per_chunk}."
|
||||
)
|
||||
|
||||
samples_4d = latent["samples"]
|
||||
if torch.count_nonzero(samples_4d) == 0:
|
||||
raise ValueError(
|
||||
"SeedVR2ProgressiveSampler: input latent is empty (all zeros). "
|
||||
"SeedVR2 is an upscaler; connect an encoded latent from "
|
||||
"'Apply SeedVR2 conditioning' rather than an empty latent."
|
||||
)
|
||||
samples_4d = comfy.sample.fix_empty_latent_channels(
|
||||
model, samples_4d,
|
||||
latent.get("downscale_ratio_spacial", None),
|
||||
)
|
||||
if samples_4d.ndim != 4:
|
||||
raise ValueError(
|
||||
f"SeedVR2ProgressiveSampler: expected 4D collapsed latent "
|
||||
f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}."
|
||||
)
|
||||
B, CT, H, W = samples_4d.shape
|
||||
if CT % SEEDVR2_LATENT_CHANNELS != 0:
|
||||
raise ValueError(
|
||||
f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is "
|
||||
f"not divisible by SeedVR2 latent channels "
|
||||
f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be "
|
||||
f"SeedVR2-shaped."
|
||||
)
|
||||
T_latent = CT // SEEDVR2_LATENT_CHANNELS
|
||||
T_pixel = 4 * (T_latent - 1) + 1
|
||||
|
||||
if chunking_mode not in ("manual", "auto"):
|
||||
raise ValueError(
|
||||
f"SeedVR2ProgressiveSampler: chunking_mode must be "
|
||||
f"'manual' or 'auto'; got {chunking_mode!r}."
|
||||
)
|
||||
|
||||
if chunking_mode == "auto":
|
||||
free_memory = comfy.model_management.get_free_memory(model.load_device)
|
||||
seed_frames_per_chunk = _seedvr2_vram_seed_frames_per_chunk(
|
||||
free_memory, T_pixel,
|
||||
)
|
||||
logging.info(
|
||||
"SeedVR2ProgressiveSampler auto: free=%.2fGB -> seeding "
|
||||
"frames_per_chunk=%s (4n+1; T_pixel=%s).",
|
||||
free_memory / (1024 ** 3), seed_frames_per_chunk, T_pixel,
|
||||
)
|
||||
attempts = _seedvr2_auto_chunk_attempts(
|
||||
T_latent, T_pixel, seed_frames_per_chunk,
|
||||
)
|
||||
for i, attempt_frames_per_chunk in enumerate(attempts):
|
||||
retry = False
|
||||
try:
|
||||
return cls.execute(
|
||||
model=model, seed=seed, steps=steps, cfg=cfg,
|
||||
sampler_name=sampler_name, scheduler=scheduler,
|
||||
positive=positive, negative=negative,
|
||||
latent=latent, denoise=denoise,
|
||||
frames_per_chunk=attempt_frames_per_chunk,
|
||||
temporal_overlap=temporal_overlap,
|
||||
chunking_mode="manual",
|
||||
)
|
||||
except Exception as e:
|
||||
comfy.model_management.raise_non_oom(e)
|
||||
if i == len(attempts) - 1:
|
||||
raise RuntimeError(
|
||||
"SeedVR2ProgressiveSampler: exhausted auto "
|
||||
"chunking attempts after OOM. Tried "
|
||||
f"frames_per_chunk values {attempts}."
|
||||
) from e
|
||||
retry = True
|
||||
|
||||
if retry:
|
||||
logging.warning(
|
||||
"SeedVR2ProgressiveSampler auto chunking OOM at "
|
||||
"frames_per_chunk=%s; retrying with "
|
||||
"frames_per_chunk=%s.",
|
||||
attempt_frames_per_chunk, attempts[i + 1],
|
||||
)
|
||||
comfy.model_management.soft_empty_cache()
|
||||
|
||||
# Short-circuit: total fits in one chunk -> standard path with no
|
||||
# chunking overhead. Output of this branch is byte-identical to the
|
||||
# built-in KSampler given the same (model, seed, steps, cfg,
|
||||
# sampler_name, scheduler, positive, negative, latent,
|
||||
# denoise) tuple.
|
||||
if T_pixel <= frames_per_chunk:
|
||||
return io.NodeOutput(_run_standard_sample(
|
||||
model, seed, steps, cfg, sampler_name, scheduler,
|
||||
positive, negative, latent, denoise,
|
||||
))
|
||||
|
||||
# Map pixel chunk -> latent chunk. Each chunk's latent length is
|
||||
# at most ``chunk_latent``; the final chunk may be a runt that
|
||||
# is automatically 4n+1-aligned in the pixel domain by the
|
||||
# T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer
|
||||
# T_latent corresponds to a valid 4n+1 pixel count).
|
||||
chunk_latent = (frames_per_chunk - 1) // 4 + 1
|
||||
|
||||
# ``temporal_overlap`` is exposed in latent-frame units, but users
|
||||
# do not know the derived latent chunk length. Treat oversized
|
||||
# values as "maximum valid overlap" while preserving a strictly
|
||||
# positive chunk-loop stride.
|
||||
if temporal_overlap < 0:
|
||||
raise ValueError(
|
||||
f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; "
|
||||
f"got {temporal_overlap}."
|
||||
)
|
||||
temporal_overlap = min(temporal_overlap, chunk_latent - 1)
|
||||
step_latent = chunk_latent - temporal_overlap
|
||||
|
||||
# Generate full noise once from the user seed, then slice along T
|
||||
# per chunk. Using one global noise tensor (rather than re-seeding
|
||||
# per chunk) preserves seed-determinism across chunk-count
|
||||
# variations: the same (seed, total T_latent) always produces the
|
||||
# same noise samples regardless of how the work is partitioned.
|
||||
batch_inds = latent.get("batch_index", None)
|
||||
noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds)
|
||||
|
||||
noise_mask = latent.get("noise_mask", None)
|
||||
|
||||
# Build the flat list of chunk ranges first so the chunking
|
||||
# geometry is fully known before any sample call.
|
||||
chunk_ranges = []
|
||||
for chunk_start in range(0, T_latent, step_latent):
|
||||
chunk_end = min(chunk_start + chunk_latent, T_latent)
|
||||
if chunk_start >= chunk_end:
|
||||
# The final iteration of a stride that lands exactly on
|
||||
# T_latent produces a zero-length chunk; skip it.
|
||||
break
|
||||
chunk_ranges.append((chunk_start, chunk_end))
|
||||
if chunk_end >= T_latent:
|
||||
break
|
||||
|
||||
def _sample_one_chunk(chunk_start, chunk_end):
|
||||
samples_chunk = _slice_collapsed_4d_along_t(
|
||||
samples_4d, chunk_start, chunk_end,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
)
|
||||
noise_chunk = _slice_collapsed_4d_along_t(
|
||||
noise_full, chunk_start, chunk_end,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
)
|
||||
positive_chunk = _slice_seedvr2_cond_along_t(
|
||||
positive, chunk_start, chunk_end,
|
||||
)
|
||||
negative_chunk = _slice_seedvr2_cond_along_t(
|
||||
negative, chunk_start, chunk_end,
|
||||
)
|
||||
|
||||
# Per-chunk noise_mask handling: standard masks are passed
|
||||
# through for KSampler expansion; pre-expanded collapsed
|
||||
# masks are sliced.
|
||||
chunk_noise_mask = None
|
||||
if noise_mask is not None:
|
||||
chunk_noise_mask = _slice_seedvr2_noise_mask_along_t(
|
||||
noise_mask, samples_4d, chunk_start, chunk_end,
|
||||
)
|
||||
|
||||
return comfy.sample.sample(
|
||||
model, noise_chunk, steps, cfg, sampler_name, scheduler,
|
||||
positive_chunk, negative_chunk, samples_chunk,
|
||||
denoise=denoise, noise_mask=chunk_noise_mask, seed=seed,
|
||||
)
|
||||
|
||||
chunk_specs = []
|
||||
for chunk_start, chunk_end in chunk_ranges:
|
||||
chunk_samples = _sample_one_chunk(chunk_start, chunk_end)
|
||||
chunk_specs.append((chunk_start, chunk_end, chunk_samples))
|
||||
|
||||
final = _concat_chunks_with_overlap_blend(
|
||||
chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap,
|
||||
)
|
||||
|
||||
out = latent.copy()
|
||||
out.pop("downscale_ratio_spacial", None)
|
||||
out["samples"] = final
|
||||
return io.NodeOutput(out)
|
||||
|
||||
|
||||
class SeedVRExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SeedVR2Conditioning,
|
||||
SeedVR2Preprocess,
|
||||
SeedVR2PostProcessing,
|
||||
SeedVR2ProgressiveSampler,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> SeedVRExtension:
|
||||
return SeedVRExtension()
|
||||
Loading…
Reference in New Issue
Block a user