SeedVR2 native nodes: apply review feedback

- Replace SeedVR2Resize / SeedVR2ResizeAdvanced with a single pad-only node
  SeedVR2Preprocess ("Pre-Process SeedVR2 Input"); resizing is delegated to the
  native Resize Image/Mask node. Input resized_images -> output processed_images.
- SeedVR2PostProcessing consumes the native-resized image directly: its inputs
  are now images + original_resized_images (the upscaled_shorter_edge input is
  removed). It de-pads the decoded output to the reference's unpadded
  dimensions, color-corrects against the reference pixels, and applies alpha
  with the same top-left de-pad crop as RGB. Output display_name is "images".
- SeedVR2Conditioning: latent input display_name "LATENT" -> "latent".
  SeedVR2ProgressiveSampler: latent input renamed latent_image -> latent
  (execute fn + _run_standard_sample); its output gets display_name "latent".
- Update the SeedVR2 unit tests for the renamed API; remove dead area_resize.
- Strip third-party-node references from comments, a docstring, and a user-facing
  error message.
This commit is contained in:
John Pollock 2026-06-06 13:14:08 -05:00
parent 81f22c335a
commit 1e08e8b724
6 changed files with 62 additions and 185 deletions

View File

@ -6,9 +6,6 @@ Provenance prefixes:
the upstream config/source path it was lifted from. the upstream config/source path it was lifted from.
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / - unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard. ISO / CIE values; cite the standard.
The numz/AInVFX custom node is used only as a behavioral-parity benchmark; it is the
origin of none of these values and appears here nowhere.
""" """
# -------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------

View File

@ -152,18 +152,6 @@ def inter(x_0, x_T, t):
B = lambda t: t / T B = lambda t: t / T
A = lambda t: 1 - (t / T) A = lambda t: 1 - (t / T)
return A(t) * x_0 + B(t) * x_T return A(t) * x_0 + B(t) * x_T
def area_resize(image, max_area):
height, width = image.shape[-2:]
scale = math.sqrt(max_area / (height * width))
resized_height, resized_width = round(height * scale), round(width * scale)
return TVF.resize(
image,
size=(resized_height, resized_width),
interpolation=InterpolationMode.BICUBIC,
)
def div_pad(image, factor): def div_pad(image, factor):
@ -202,12 +190,6 @@ def cut_videos(videos):
assert (videos.size(1) - 1) % (4) == 0 assert (videos.size(1) - 1) % (4) == 0
return videos return videos
def side_resize(image, size):
antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)
return resized
def _seedvr2_input_shorter_edge(images, node_name): def _seedvr2_input_shorter_edge(images, node_name):
if images.dim() == 4: if images.dim() == 4:
return min(images.shape[1], images.shape[2]) return min(images.shape[1], images.shape[2])
@ -219,13 +201,12 @@ def _seedvr2_input_shorter_edge(images, node_name):
) )
def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name): def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
if upscaled_shorter_edge < 2: if upscaled_shorter_edge < 2:
raise ValueError( raise ValueError(
f"{node_name}: resolved upscaled_shorter_edge must be at least 2 pixels; " f"{node_name}: input shorter edge must be at least 2 pixels; "
f"got {upscaled_shorter_edge}." f"got {upscaled_shorter_edge}."
) )
original_image = images
if images.shape[-1] > 3: if images.shape[-1] > 3:
images = images[..., :3] images = images[..., :3]
if images.dim() == 4: if images.dim() == 4:
@ -243,8 +224,6 @@ def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
images = images.reshape(b * t, c, h, w) images = images.reshape(b * t, c, h, w)
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
images = side_resize(images, upscaled_shorter_edge)
images = clip(images) images = clip(images)
images = div_pad(images, (16, 16)) images = div_pad(images, (16, 16))
_, _, new_h, new_w = images.shape _, _, new_h, new_w = images.shape
@ -253,82 +232,31 @@ def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
images = cut_videos(images) images = cut_videos(images)
images_bthwc = rearrange(images, "b t c h w -> b t h w c") images_bthwc = rearrange(images, "b t c h w -> b t h w c")
return io.NodeOutput(images_bthwc, original_image, upscaled_shorter_edge) return io.NodeOutput(images_bthwc)
class SeedVR2Resize(io.ComfyNode): class SeedVR2Preprocess(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="SeedVR2Resize", node_id="SeedVR2Preprocess",
display_name="Resize Image for SeedVR2", display_name="Pre-Process SeedVR2 Input",
category="image/upscaling", category="image/upscaling",
description="Resize an image to a SeedVR2-compatible size by a multiplier.", description="Pad an already-resized image to SeedVR2 model alignment (resize upstream with Resize Image/Mask). Any alpha is dropped for the model; Post-Process SeedVR2 Output re-applies it from the original resized image.",
inputs=[ inputs=[
io.Image.Input("images", tooltip="The image(s) to resize."), io.Image.Input("resized_images", tooltip="The already-resized (SeedVR2-sized) image(s). Any alpha is dropped before padding for the model and re-applied downstream by Post-Process SeedVR2 Output."),
io.Float.Input("multiplier", default=4.0, min=0.01, tooltip="Upscale factor applied to the shorter edge."),
], ],
outputs=[ outputs=[
io.Image.Output("input_pixels"), io.Image.Output("processed_images"),
io.Image.Output("original_image"),
io.Int.Output("upscaled_shorter_edge"),
] ]
) )
@classmethod @classmethod
def execute(cls, images, multiplier=4.0): def execute(cls, resized_images):
if multiplier <= 0: upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess")
raise ValueError( return _seedvr2_pad(
f"SeedVR2Resize: multiplier must be > 0; got {multiplier}." resized_images, upscaled_shorter_edge, "SeedVR2Preprocess",
) )
shorter_edge = _seedvr2_input_shorter_edge(images, "SeedVR2Resize")
upscaled_shorter_edge = int(round(shorter_edge * multiplier))
if upscaled_shorter_edge < 2:
raise ValueError(
"SeedVR2Resize: multiplier resolved upscaled_shorter_edge "
f"to {upscaled_shorter_edge}; use a multiplier that resolves "
"to at least 2 pixels."
)
return _seedvr2_resize_and_pad(
images, upscaled_shorter_edge, "SeedVR2Resize",
)
class SeedVR2ResizeAdvanced(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedVR2ResizeAdvanced",
display_name="Resize Image for SeedVR2 (Advanced)",
category="image/upscaling",
description="Resize an image to an exact shorter-edge size for SeedVR2.",
inputs=[
io.Image.Input("images", tooltip="The image(s) to resize."),
io.Int.Input("shorter_edge", default=1280, min=2, tooltip="Target length of the shorter edge, in pixels."),
],
outputs=[
io.Image.Output("input_pixels"),
io.Image.Output("original_image"),
io.Int.Output("upscaled_shorter_edge"),
]
)
@classmethod
def execute(cls, images, shorter_edge):
return _seedvr2_resize_and_pad(
images, shorter_edge, "SeedVR2ResizeAdvanced",
)
def _edge_guided_alpha_upscale(alpha, out_h, out_w):
a = alpha.float()
extreme_fraction = ((a < 0.1) | (a > 0.9)).float().mean()
if extreme_fraction > 0.9:
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bilinear", align_corners=False, antialias=True)
up = torch.clamp((up - 0.5) * 4.0 + 0.5, 0.0, 1.0)
else:
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bicubic", align_corners=False, antialias=True).clamp(0.0, 1.0)
return up
class SeedVR2PostProcessing(io.ComfyNode): class SeedVR2PostProcessing(io.ComfyNode):
@ -338,40 +266,36 @@ class SeedVR2PostProcessing(io.ComfyNode):
node_id="SeedVR2PostProcessing", node_id="SeedVR2PostProcessing",
display_name="Post-Process SeedVR2 Output", display_name="Post-Process SeedVR2 Output",
category="image/upscaling", category="image/upscaling",
description="Align the upscaled output to the original's geometry and optionally color-correct it against the original.", description="Align the upscaled output to the original resized image's geometry (de-pad) and optionally color-correct it against that image.",
inputs=[ inputs=[
io.Image.Input("decoded", tooltip="The decoded upscaled image to color-correct."), io.Image.Input("images", tooltip="The decoded upscaled image to color-correct."),
io.Image.Input("original_image", tooltip="The original image used as the color reference."), io.Image.Input("original_resized_images", tooltip="The original resized (pre-pad) image used as the geometry and color reference."),
io.Int.Input("upscaled_shorter_edge", min=2, force_input=True, tooltip="Shorter-edge size from the resize node."),
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="How to match the output's color to the original. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."), io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="How to match the output's color to the original. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
], ],
outputs=[io.Image.Output()], outputs=[io.Image.Output(display_name="images")],
) )
@classmethod @classmethod
def execute(cls, decoded, original_image, upscaled_shorter_edge, color_correction_method): def execute(cls, images, original_resized_images, color_correction_method):
cls._validate_upscaled_shorter_edge(upscaled_shorter_edge)
alpha_input = None alpha_input = None
if original_image.shape[-1] == 4: if original_resized_images.shape[-1] == 4:
alpha_input = original_image[..., 3:4] alpha_input = original_resized_images[..., 3:4]
original_image = original_image[..., :3] original_resized_images = original_resized_images[..., :3]
decoded_5d, decoded_was_4d = cls._as_bthwc(decoded) decoded_5d, decoded_was_4d = cls._as_bthwc(images)
original_5d, _ = cls._as_bthwc(original_image) reference_full, _ = cls._as_bthwc(original_resized_images)
decoded_5d = cls._restore_reference_batch_time(decoded_5d, original_5d) decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full)
b = min(decoded_5d.shape[0], original_5d.shape[0]) b = min(decoded_5d.shape[0], reference_full.shape[0])
t = min(decoded_5d.shape[1], original_5d.shape[1]) t = min(decoded_5d.shape[1], reference_full.shape[1])
reference_h, reference_w = cls._resized_shorter_edge_dims( reference_h = reference_full.shape[2]
original_5d.shape[2], original_5d.shape[3], upscaled_shorter_edge, reference_w = reference_full.shape[3]
)
decoded_5d = decoded_5d[:b, :t, :, :, :] decoded_5d = decoded_5d[:b, :t, :, :, :]
target_h = min(decoded_5d.shape[2], reference_h) target_h = min(decoded_5d.shape[2], reference_h)
target_w = min(decoded_5d.shape[3], reference_w) target_w = min(decoded_5d.shape[3], reference_w)
decoded_5d = decoded_5d[:, :, :target_h, :target_w, :] decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]
if color_correction_method in ("lab", "wavelet", "adain"): if color_correction_method in ("lab", "wavelet", "adain"):
reference_5d = cls._resize_original_reference(original_image, upscaled_shorter_edge) reference_5d = reference_full[:b, :t, :, :, :]
reference_5d = reference_5d[:b, :t, :, :, :]
reference_5d = cls._resize_reference(reference_5d, target_h, target_w) reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
output_device = decoded_5d.device output_device = decoded_5d.device
decoded_raw = cls._to_seedvr2_raw(decoded_5d) decoded_raw = cls._to_seedvr2_raw(decoded_5d)
@ -389,12 +313,9 @@ class SeedVR2PostProcessing(io.ComfyNode):
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
if alpha_input is not None: if alpha_input is not None:
ab, at = output.shape[0], output.shape[1]
alpha_5d, _ = cls._as_bthwc(alpha_input) alpha_5d, _ = cls._as_bthwc(alpha_input)
alpha_flat = rearrange(alpha_5d[:ab, :at], "b t h w c -> (b t) c h w") alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :]
alpha_up = _edge_guided_alpha_upscale(alpha_flat, output.shape[2], output.shape[3]) output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1)
alpha_up = rearrange(alpha_up, "(b t) c h w -> b t h w c", b=ab, t=at)
output = torch.cat([output, alpha_up.to(dtype=output.dtype, device=output.device)], dim=-1)
h2 = output.shape[-3] - (output.shape[-3] % 2) h2 = output.shape[-3] - (output.shape[-3] % 2)
w2 = output.shape[-2] - (output.shape[-2] % 2) w2 = output.shape[-2] - (output.shape[-2] % 2)
output = output[:, :, :h2, :w2, :] output = output[:, :, :h2, :w2, :]
@ -428,28 +349,6 @@ class SeedVR2PostProcessing(io.ComfyNode):
def _to_seedvr2_raw(images): def _to_seedvr2_raw(images):
return images.mul(2.0).sub(1.0) return images.mul(2.0).sub(1.0)
@staticmethod
def _validate_upscaled_shorter_edge(upscaled_shorter_edge):
if not isinstance(upscaled_shorter_edge, int) or upscaled_shorter_edge < 2:
raise ValueError(
"SeedVR2PostProcessing: upscaled_shorter_edge must be an integer "
f"of at least 2 pixels; got {upscaled_shorter_edge!r}."
)
@staticmethod
def _resized_shorter_edge_dims(height, width, upscaled_shorter_edge):
if height <= width:
return upscaled_shorter_edge, int(upscaled_shorter_edge * width / height)
return int(upscaled_shorter_edge * height / width), upscaled_shorter_edge
@classmethod
def _resize_original_reference(cls, original, upscaled_shorter_edge):
original_5d, _ = cls._as_bthwc(original)
b, t = original_5d.shape[:2]
original_flat = rearrange(original_5d, "b t h w c -> (b t) c h w")
resized_flat = side_resize(original_flat, upscaled_shorter_edge).clamp(0.0, 1.0)
return rearrange(resized_flat, "(b t) c h w -> b t h w c", b=b, t=t)
@staticmethod @staticmethod
def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn): def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
color_device = comfy.model_management.vae_device() color_device = comfy.model_management.vae_device()
@ -575,7 +474,7 @@ class SeedVR2Conditioning(io.ComfyNode):
description="Build SeedVR2 positive/negative conditioning from a VAE latent.", description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
inputs=[ inputs=[
io.Model.Input("model", tooltip="The SeedVR2 model."), io.Model.Input("model", tooltip="The SeedVR2 model."),
io.Latent.Input("vae_conditioning", display_name="LATENT", tooltip="The VAE-encoded latent to condition on."), io.Latent.Input("vae_conditioning", display_name="latent", tooltip="The VAE-encoded latent to condition on."),
], ],
outputs=[ outputs=[
io.Model.Output(display_name = "model"), io.Model.Output(display_name = "model"),
@ -606,7 +505,7 @@ class SeedVR2Conditioning(io.ComfyNode):
pos_cond = model.positive_conditioning pos_cond = model.positive_conditioning
neg_cond = model.negative_conditioning neg_cond = model.negative_conditioning
# Fail-loud guard against silently-wrong output when a numz-format # Fail-loud guard against silently-wrong output when a
# DiT-only ``.safetensors`` (no ``positive_conditioning`` / # DiT-only ``.safetensors`` (no ``positive_conditioning`` /
# ``negative_conditioning`` keys) is loaded via ``UNETLoader``. # ``negative_conditioning`` keys) is loaded via ``UNETLoader``.
# ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see
@ -623,7 +522,7 @@ class SeedVR2Conditioning(io.ComfyNode):
raise RuntimeError( raise RuntimeError(
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
f"and negative_conditioning buffers are zero-valued — model " f"and negative_conditioning buffers are zero-valued — model "
f"file appears to be a numz-format DiT-only export missing " f"file appears to be a DiT-only export missing "
f"the SeedVR2 conditioning tensors. " f"the SeedVR2 conditioning tensors. "
f"Re-bake the file with ``positive_conditioning`` (58, 5120) " f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
f"and ``negative_conditioning`` (64, 5120) keys at top level, " f"and ``negative_conditioning`` (64, 5120) keys at top level, "
@ -716,8 +615,8 @@ def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
"""1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``):
Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3``
(dead-band would collapse a tiny transition). Window shape matched to numz ``blend_overlapping_frames`` (dead-band would collapse a tiny transition). Window shape matched to the reference
for parity (reference, not source); caller broadcasts across ``(B, C, T_overlap, H, W)``. overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``.
""" """
if overlap < 1: if overlap < 1:
raise ValueError( raise ValueError(
@ -831,22 +730,22 @@ def _concat_chunks_with_overlap_blend(chunk_specs, channels: int,
def _run_standard_sample(model, seed: int, steps: int, cfg: float, def _run_standard_sample(model, seed: int, steps: int, cfg: float,
sampler_name: str, scheduler: str, sampler_name: str, scheduler: str,
positive, negative, latent_image: dict, positive, negative, latent: dict,
denoise: float) -> dict: denoise: float) -> dict:
"""Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk."""
samples_in = latent_image["samples"] samples_in = latent["samples"]
samples_in = comfy.sample.fix_empty_latent_channels( samples_in = comfy.sample.fix_empty_latent_channels(
model, samples_in, latent_image.get("downscale_ratio_spacial", None), model, samples_in, latent.get("downscale_ratio_spacial", None),
) )
batch_inds = latent_image.get("batch_index", None) batch_inds = latent.get("batch_index", None)
noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds)
noise_mask = latent_image.get("noise_mask", None) noise_mask = latent.get("noise_mask", None)
samples = comfy.sample.sample( samples = comfy.sample.sample(
model, noise, steps, cfg, sampler_name, scheduler, model, noise, steps, cfg, sampler_name, scheduler,
positive, negative, samples_in, positive, negative, samples_in,
denoise=denoise, noise_mask=noise_mask, seed=seed, denoise=denoise, noise_mask=noise_mask, seed=seed,
) )
out = latent_image.copy() out = latent.copy()
out.pop("downscale_ratio_spacial", None) out.pop("downscale_ratio_spacial", None)
out["samples"] = samples out["samples"] = samples
return out return out
@ -904,7 +803,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
tooltip="The conditioning describing the attributes you want to include in the image."), tooltip="The conditioning describing the attributes you want to include in the image."),
io.Conditioning.Input("negative", io.Conditioning.Input("negative",
tooltip="The conditioning describing the attributes you want to exclude from the image."), tooltip="The conditioning describing the attributes you want to exclude from the image."),
io.Latent.Input("latent_image", io.Latent.Input("latent",
tooltip="The latent image to denoise."), tooltip="The latent image to denoise."),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, io.Float.Input("denoise", default=1.0, min=0.0, max=1.0,
step=0.01, step=0.01,
@ -920,12 +819,12 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
default="manual", default="manual",
tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."),
], ],
outputs=[io.Latent.Output()], outputs=[io.Latent.Output(display_name="latent")],
) )
@classmethod @classmethod
def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, def execute(cls, model, seed, steps, cfg, sampler_name, scheduler,
positive, negative, latent_image, denoise, positive, negative, latent, denoise,
frames_per_chunk, temporal_overlap, frames_per_chunk, temporal_overlap,
chunking_mode="manual") -> io.NodeOutput: chunking_mode="manual") -> io.NodeOutput:
# 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline
@ -941,10 +840,10 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
f"got {frames_per_chunk}." f"got {frames_per_chunk}."
) )
samples_4d = latent_image["samples"] samples_4d = latent["samples"]
samples_4d = comfy.sample.fix_empty_latent_channels( samples_4d = comfy.sample.fix_empty_latent_channels(
model, samples_4d, model, samples_4d,
latent_image.get("downscale_ratio_spacial", None), latent.get("downscale_ratio_spacial", None),
) )
if samples_4d.ndim != 4: if samples_4d.ndim != 4:
raise ValueError( raise ValueError(
@ -979,7 +878,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
model=model, seed=seed, steps=steps, cfg=cfg, model=model, seed=seed, steps=steps, cfg=cfg,
sampler_name=sampler_name, scheduler=scheduler, sampler_name=sampler_name, scheduler=scheduler,
positive=positive, negative=negative, positive=positive, negative=negative,
latent_image=latent_image, denoise=denoise, latent=latent, denoise=denoise,
frames_per_chunk=attempt_frames_per_chunk, frames_per_chunk=attempt_frames_per_chunk,
temporal_overlap=temporal_overlap, temporal_overlap=temporal_overlap,
chunking_mode="manual", chunking_mode="manual",
@ -1006,12 +905,12 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
# Short-circuit: total fits in one chunk -> standard path with no # Short-circuit: total fits in one chunk -> standard path with no
# chunking overhead. Output of this branch is byte-identical to the # chunking overhead. Output of this branch is byte-identical to the
# built-in KSampler given the same (model, seed, steps, cfg, # built-in KSampler given the same (model, seed, steps, cfg,
# sampler_name, scheduler, positive, negative, latent_image, # sampler_name, scheduler, positive, negative, latent,
# denoise) tuple. # denoise) tuple.
if T_pixel <= frames_per_chunk: if T_pixel <= frames_per_chunk:
return io.NodeOutput(_run_standard_sample( return io.NodeOutput(_run_standard_sample(
model, seed, steps, cfg, sampler_name, scheduler, model, seed, steps, cfg, sampler_name, scheduler,
positive, negative, latent_image, denoise, positive, negative, latent, denoise,
)) ))
# Map pixel chunk -> latent chunk. Each chunk's latent length is # Map pixel chunk -> latent chunk. Each chunk's latent length is
@ -1038,10 +937,10 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
# per chunk) preserves seed-determinism across chunk-count # per chunk) preserves seed-determinism across chunk-count
# variations: the same (seed, total T_latent) always produces the # variations: the same (seed, total T_latent) always produces the
# same noise samples regardless of how the work is partitioned. # same noise samples regardless of how the work is partitioned.
batch_inds = latent_image.get("batch_index", None) batch_inds = latent.get("batch_index", None)
noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds)
noise_mask = latent_image.get("noise_mask", None) noise_mask = latent.get("noise_mask", None)
# Build the flat list of chunk ranges first so the chunking # Build the flat list of chunk ranges first so the chunking
# geometry is fully known before any sample call. # geometry is fully known before any sample call.
@ -1096,7 +995,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap,
) )
out = latent_image.copy() out = latent.copy()
out.pop("downscale_ratio_spacial", None) out.pop("downscale_ratio_spacial", None)
out["samples"] = final out["samples"] = final
return io.NodeOutput(out) return io.NodeOutput(out)
@ -1107,8 +1006,7 @@ class SeedVRExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ return [
SeedVR2Conditioning, SeedVR2Conditioning,
SeedVR2Resize, SeedVR2Preprocess,
SeedVR2ResizeAdvanced,
SeedVR2PostProcessing, SeedVR2PostProcessing,
SeedVR2ProgressiveSampler, SeedVR2ProgressiveSampler,
] ]

View File

@ -123,7 +123,7 @@ def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
"model", "model",
"vae_conditioning", "vae_conditioning",
] ]
assert schema.inputs[1].display_name == "LATENT" assert schema.inputs[1].display_name == "latent"
assert [output.display_name for output in schema.outputs] == [ assert [output.display_name for output in schema.outputs] == [
"model", "model",
"positive", "positive",

View File

@ -10,21 +10,6 @@ from comfy.cli_args import args as cli_args
if not torch.cuda.is_available(): if not torch.cuda.is_available():
cli_args.cpu = True cli_args.cpu = True
import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402
def test_resize_simple_multiplier_resolves_upscaled_shorter_edge():
images = torch.zeros(1, 3, 16, 20, 3)
output = nodes_seedvr.SeedVR2Resize.execute(images, 4.0)
input_pixels, original_image, upscaled_shorter_edge = output.result
assert tuple(input_pixels.shape) == (1, 5, 64, 80, 3)
assert input_pixels.min().item() == 0.0
assert input_pixels.max().item() == 0.0
assert original_image is images
assert upscaled_shorter_edge == 64
def test_seedvr_node_signature_matches_schema(): def test_seedvr_node_signature_matches_schema():
mock_mm = MagicMock() mock_mm = MagicMock()
@ -46,7 +31,7 @@ def test_seedvr_node_signature_matches_schema():
sys.modules.pop("comfy_extras.nodes_seedvr", None) sys.modules.pop("comfy_extras.nodes_seedvr", None)
try: try:
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
for node_cls in (nodes_seedvr.SeedVR2Resize, nodes_seedvr.SeedVR2ResizeAdvanced): for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler):
schema_ids = [i.id for i in node_cls.define_schema().inputs] schema_ids = [i.id for i in node_cls.define_schema().inputs]
exec_params = [ exec_params = [
p for p in inspect.signature(node_cls.execute).parameters.keys() p for p in inspect.signature(node_cls.execute).parameters.keys()

View File

@ -17,12 +17,9 @@ def _schema_ids(items):
def test_seedvr2_post_processing_schema(): def test_seedvr2_post_processing_schema():
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
assert _schema_ids(schema.inputs) == ["decoded", "original_image", "upscaled_shorter_edge", "color_correction_method"] assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
assert schema.inputs[2].default is None assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
assert schema.inputs[2].min == 2 assert schema.inputs[2].default == "lab"
assert schema.inputs[2].force_input is True
assert schema.inputs[3].options == ["lab", "wavelet", "adain", "none"]
assert schema.inputs[3].default == "lab"
assert schema.outputs[0].get_io_type() == "IMAGE" assert schema.outputs[0].get_io_type() == "IMAGE"
@ -53,7 +50,7 @@ def test_seedvr2_post_processing_unknown_color_correction_method_raises():
decoded = torch.zeros(1, 2, 4, 4, 3) decoded = torch.zeros(1, 2, 4, 4, 3)
original = torch.zeros(1, 2, 4, 4, 3) original = torch.zeros(1, 2, 4, 4, 3)
try: try:
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus") nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
except ValueError as exc: except ValueError as exc:
assert "color_correction_method" in str(exc) assert "color_correction_method" in str(exc)
else: else:

View File

@ -83,7 +83,7 @@ def test_auto_chunking_walks_two_three_four_chunk_ladder():
out = SeedVR2ProgressiveSampler.execute( out = SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0, model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple", sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent_image=latent, positive=pos, negative=neg, latent=latent,
denoise=1.0, frames_per_chunk=65, temporal_overlap=0, denoise=1.0, frames_per_chunk=65, temporal_overlap=0,
chunking_mode="auto", chunking_mode="auto",
) )
@ -119,7 +119,7 @@ def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
SeedVR2ProgressiveSampler.execute( SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0, model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple", sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent_image=latent, positive=pos, negative=neg, latent=latent,
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
) )
assert str(bad_chunk) in str(excinfo.value) assert str(bad_chunk) in str(excinfo.value)