mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 09:49:26 +08:00
SeedVR2 native nodes: apply review feedback
- Replace SeedVR2Resize / SeedVR2ResizeAdvanced with a single pad-only node
SeedVR2Preprocess ("Pre-Process SeedVR2 Input"); resizing is delegated to the
native Resize Image/Mask node. Input resized_images -> output processed_images.
- SeedVR2PostProcessing consumes the native-resized image directly: inputs images
+ original_resized_images; de-pad the decode to the reference's unpadded
dimensions; color-correct against the reference pixels; apply alpha with the same
top-left de-pad crop as RGB. Output display_name "images".
- SeedVR2Conditioning latent input display "latent"; SeedVR2ProgressiveSampler
latent input/output latent_image -> latent (execute fn + _run_standard_sample).
- Update the SeedVR2 unit tests for the renamed API; remove dead area_resize.
- Strip third-party-node references from comments, a docstring, and a user-facing
error message.
This commit is contained in:
parent
81f22c335a
commit
1e08e8b724
@ -6,9 +6,6 @@ Provenance prefixes:
|
|||||||
the upstream config/source path it was lifted from.
|
the upstream config/source path it was lifted from.
|
||||||
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
|
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
|
||||||
ISO / CIE values; cite the standard.
|
ISO / CIE values; cite the standard.
|
||||||
|
|
||||||
The numz/AInVFX custom node is used only as a behavioral-parity benchmark; it is the
|
|
||||||
origin of none of these values and appears here nowhere.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------
|
||||||
|
|||||||
@ -152,18 +152,6 @@ def inter(x_0, x_T, t):
|
|||||||
B = lambda t: t / T
|
B = lambda t: t / T
|
||||||
A = lambda t: 1 - (t / T)
|
A = lambda t: 1 - (t / T)
|
||||||
return A(t) * x_0 + B(t) * x_T
|
return A(t) * x_0 + B(t) * x_T
|
||||||
def area_resize(image, max_area):
|
|
||||||
|
|
||||||
height, width = image.shape[-2:]
|
|
||||||
scale = math.sqrt(max_area / (height * width))
|
|
||||||
|
|
||||||
resized_height, resized_width = round(height * scale), round(width * scale)
|
|
||||||
|
|
||||||
return TVF.resize(
|
|
||||||
image,
|
|
||||||
size=(resized_height, resized_width),
|
|
||||||
interpolation=InterpolationMode.BICUBIC,
|
|
||||||
)
|
|
||||||
|
|
||||||
def div_pad(image, factor):
|
def div_pad(image, factor):
|
||||||
|
|
||||||
@ -202,12 +190,6 @@ def cut_videos(videos):
|
|||||||
assert (videos.size(1) - 1) % (4) == 0
|
assert (videos.size(1) - 1) % (4) == 0
|
||||||
return videos
|
return videos
|
||||||
|
|
||||||
def side_resize(image, size):
|
|
||||||
antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
|
|
||||||
resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)
|
|
||||||
return resized
|
|
||||||
|
|
||||||
|
|
||||||
def _seedvr2_input_shorter_edge(images, node_name):
|
def _seedvr2_input_shorter_edge(images, node_name):
|
||||||
if images.dim() == 4:
|
if images.dim() == 4:
|
||||||
return min(images.shape[1], images.shape[2])
|
return min(images.shape[1], images.shape[2])
|
||||||
@ -219,13 +201,12 @@ def _seedvr2_input_shorter_edge(images, node_name):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
|
def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
|
||||||
if upscaled_shorter_edge < 2:
|
if upscaled_shorter_edge < 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{node_name}: resolved upscaled_shorter_edge must be at least 2 pixels; "
|
f"{node_name}: input shorter edge must be at least 2 pixels; "
|
||||||
f"got {upscaled_shorter_edge}."
|
f"got {upscaled_shorter_edge}."
|
||||||
)
|
)
|
||||||
original_image = images
|
|
||||||
if images.shape[-1] > 3:
|
if images.shape[-1] > 3:
|
||||||
images = images[..., :3]
|
images = images[..., :3]
|
||||||
if images.dim() == 4:
|
if images.dim() == 4:
|
||||||
@ -243,8 +224,6 @@ def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
|
|||||||
images = images.reshape(b * t, c, h, w)
|
images = images.reshape(b * t, c, h, w)
|
||||||
|
|
||||||
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
||||||
images = side_resize(images, upscaled_shorter_edge)
|
|
||||||
|
|
||||||
images = clip(images)
|
images = clip(images)
|
||||||
images = div_pad(images, (16, 16))
|
images = div_pad(images, (16, 16))
|
||||||
_, _, new_h, new_w = images.shape
|
_, _, new_h, new_w = images.shape
|
||||||
@ -253,84 +232,33 @@ def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
|
|||||||
images = cut_videos(images)
|
images = cut_videos(images)
|
||||||
images_bthwc = rearrange(images, "b t c h w -> b t h w c")
|
images_bthwc = rearrange(images, "b t c h w -> b t h w c")
|
||||||
|
|
||||||
return io.NodeOutput(images_bthwc, original_image, upscaled_shorter_edge)
|
return io.NodeOutput(images_bthwc)
|
||||||
|
|
||||||
|
|
||||||
class SeedVR2Resize(io.ComfyNode):
|
class SeedVR2Preprocess(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SeedVR2Resize",
|
node_id="SeedVR2Preprocess",
|
||||||
display_name="Resize Image for SeedVR2",
|
display_name="Pre-Process SeedVR2 Input",
|
||||||
category="image/upscaling",
|
category="image/upscaling",
|
||||||
description="Resize an image to a SeedVR2-compatible size by a multiplier.",
|
description="Pad an already-resized image to SeedVR2 model alignment (resize upstream with Resize Image/Mask). Any alpha is dropped for the model; Post-Process SeedVR2 Output re-applies it from the original resized image.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("images", tooltip="The image(s) to resize."),
|
io.Image.Input("resized_images", tooltip="The already-resized (SeedVR2-sized) image(s). Any alpha is dropped before padding for the model and re-applied downstream by Post-Process SeedVR2 Output."),
|
||||||
io.Float.Input("multiplier", default=4.0, min=0.01, tooltip="Upscale factor applied to the shorter edge."),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Image.Output("input_pixels"),
|
io.Image.Output("processed_images"),
|
||||||
io.Image.Output("original_image"),
|
|
||||||
io.Int.Output("upscaled_shorter_edge"),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, images, multiplier=4.0):
|
def execute(cls, resized_images):
|
||||||
if multiplier <= 0:
|
upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess")
|
||||||
raise ValueError(
|
return _seedvr2_pad(
|
||||||
f"SeedVR2Resize: multiplier must be > 0; got {multiplier}."
|
resized_images, upscaled_shorter_edge, "SeedVR2Preprocess",
|
||||||
)
|
|
||||||
shorter_edge = _seedvr2_input_shorter_edge(images, "SeedVR2Resize")
|
|
||||||
upscaled_shorter_edge = int(round(shorter_edge * multiplier))
|
|
||||||
if upscaled_shorter_edge < 2:
|
|
||||||
raise ValueError(
|
|
||||||
"SeedVR2Resize: multiplier resolved upscaled_shorter_edge "
|
|
||||||
f"to {upscaled_shorter_edge}; use a multiplier that resolves "
|
|
||||||
"to at least 2 pixels."
|
|
||||||
)
|
|
||||||
return _seedvr2_resize_and_pad(
|
|
||||||
images, upscaled_shorter_edge, "SeedVR2Resize",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SeedVR2ResizeAdvanced(io.ComfyNode):
|
|
||||||
@classmethod
|
|
||||||
def define_schema(cls):
|
|
||||||
return io.Schema(
|
|
||||||
node_id="SeedVR2ResizeAdvanced",
|
|
||||||
display_name="Resize Image for SeedVR2 (Advanced)",
|
|
||||||
category="image/upscaling",
|
|
||||||
description="Resize an image to an exact shorter-edge size for SeedVR2.",
|
|
||||||
inputs=[
|
|
||||||
io.Image.Input("images", tooltip="The image(s) to resize."),
|
|
||||||
io.Int.Input("shorter_edge", default=1280, min=2, tooltip="Target length of the shorter edge, in pixels."),
|
|
||||||
],
|
|
||||||
outputs=[
|
|
||||||
io.Image.Output("input_pixels"),
|
|
||||||
io.Image.Output("original_image"),
|
|
||||||
io.Int.Output("upscaled_shorter_edge"),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute(cls, images, shorter_edge):
|
|
||||||
return _seedvr2_resize_and_pad(
|
|
||||||
images, shorter_edge, "SeedVR2ResizeAdvanced",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _edge_guided_alpha_upscale(alpha, out_h, out_w):
|
|
||||||
a = alpha.float()
|
|
||||||
extreme_fraction = ((a < 0.1) | (a > 0.9)).float().mean()
|
|
||||||
if extreme_fraction > 0.9:
|
|
||||||
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bilinear", align_corners=False, antialias=True)
|
|
||||||
up = torch.clamp((up - 0.5) * 4.0 + 0.5, 0.0, 1.0)
|
|
||||||
else:
|
|
||||||
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bicubic", align_corners=False, antialias=True).clamp(0.0, 1.0)
|
|
||||||
return up
|
|
||||||
|
|
||||||
|
|
||||||
class SeedVR2PostProcessing(io.ComfyNode):
|
class SeedVR2PostProcessing(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -338,40 +266,36 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
|||||||
node_id="SeedVR2PostProcessing",
|
node_id="SeedVR2PostProcessing",
|
||||||
display_name="Post-Process SeedVR2 Output",
|
display_name="Post-Process SeedVR2 Output",
|
||||||
category="image/upscaling",
|
category="image/upscaling",
|
||||||
description="Align the upscaled output to the original's geometry and optionally color-correct it against the original.",
|
description="Align the upscaled output to the original resized image's geometry (de-pad) and optionally color-correct it against that image.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("decoded", tooltip="The decoded upscaled image to color-correct."),
|
io.Image.Input("images", tooltip="The decoded upscaled image to color-correct."),
|
||||||
io.Image.Input("original_image", tooltip="The original image used as the color reference."),
|
io.Image.Input("original_resized_images", tooltip="The original resized (pre-pad) image used as the geometry and color reference."),
|
||||||
io.Int.Input("upscaled_shorter_edge", min=2, force_input=True, tooltip="Shorter-edge size from the resize node."),
|
|
||||||
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="How to match the output's color to the original. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
|
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="How to match the output's color to the original. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
|
||||||
],
|
],
|
||||||
outputs=[io.Image.Output()],
|
outputs=[io.Image.Output(display_name="images")],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, decoded, original_image, upscaled_shorter_edge, color_correction_method):
|
def execute(cls, images, original_resized_images, color_correction_method):
|
||||||
cls._validate_upscaled_shorter_edge(upscaled_shorter_edge)
|
|
||||||
alpha_input = None
|
alpha_input = None
|
||||||
if original_image.shape[-1] == 4:
|
if original_resized_images.shape[-1] == 4:
|
||||||
alpha_input = original_image[..., 3:4]
|
alpha_input = original_resized_images[..., 3:4]
|
||||||
original_image = original_image[..., :3]
|
original_resized_images = original_resized_images[..., :3]
|
||||||
decoded_5d, decoded_was_4d = cls._as_bthwc(decoded)
|
decoded_5d, decoded_was_4d = cls._as_bthwc(images)
|
||||||
original_5d, _ = cls._as_bthwc(original_image)
|
reference_full, _ = cls._as_bthwc(original_resized_images)
|
||||||
decoded_5d = cls._restore_reference_batch_time(decoded_5d, original_5d)
|
decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full)
|
||||||
|
|
||||||
b = min(decoded_5d.shape[0], original_5d.shape[0])
|
b = min(decoded_5d.shape[0], reference_full.shape[0])
|
||||||
t = min(decoded_5d.shape[1], original_5d.shape[1])
|
t = min(decoded_5d.shape[1], reference_full.shape[1])
|
||||||
reference_h, reference_w = cls._resized_shorter_edge_dims(
|
reference_h = reference_full.shape[2]
|
||||||
original_5d.shape[2], original_5d.shape[3], upscaled_shorter_edge,
|
reference_w = reference_full.shape[3]
|
||||||
)
|
|
||||||
|
|
||||||
decoded_5d = decoded_5d[:b, :t, :, :, :]
|
decoded_5d = decoded_5d[:b, :t, :, :, :]
|
||||||
target_h = min(decoded_5d.shape[2], reference_h)
|
target_h = min(decoded_5d.shape[2], reference_h)
|
||||||
target_w = min(decoded_5d.shape[3], reference_w)
|
target_w = min(decoded_5d.shape[3], reference_w)
|
||||||
decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]
|
decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]
|
||||||
if color_correction_method in ("lab", "wavelet", "adain"):
|
if color_correction_method in ("lab", "wavelet", "adain"):
|
||||||
reference_5d = cls._resize_original_reference(original_image, upscaled_shorter_edge)
|
reference_5d = reference_full[:b, :t, :, :, :]
|
||||||
reference_5d = reference_5d[:b, :t, :, :, :]
|
|
||||||
reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
|
reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
|
||||||
output_device = decoded_5d.device
|
output_device = decoded_5d.device
|
||||||
decoded_raw = cls._to_seedvr2_raw(decoded_5d)
|
decoded_raw = cls._to_seedvr2_raw(decoded_5d)
|
||||||
@ -389,12 +313,9 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
|||||||
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
||||||
|
|
||||||
if alpha_input is not None:
|
if alpha_input is not None:
|
||||||
ab, at = output.shape[0], output.shape[1]
|
|
||||||
alpha_5d, _ = cls._as_bthwc(alpha_input)
|
alpha_5d, _ = cls._as_bthwc(alpha_input)
|
||||||
alpha_flat = rearrange(alpha_5d[:ab, :at], "b t h w c -> (b t) c h w")
|
alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :]
|
||||||
alpha_up = _edge_guided_alpha_upscale(alpha_flat, output.shape[2], output.shape[3])
|
output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1)
|
||||||
alpha_up = rearrange(alpha_up, "(b t) c h w -> b t h w c", b=ab, t=at)
|
|
||||||
output = torch.cat([output, alpha_up.to(dtype=output.dtype, device=output.device)], dim=-1)
|
|
||||||
h2 = output.shape[-3] - (output.shape[-3] % 2)
|
h2 = output.shape[-3] - (output.shape[-3] % 2)
|
||||||
w2 = output.shape[-2] - (output.shape[-2] % 2)
|
w2 = output.shape[-2] - (output.shape[-2] % 2)
|
||||||
output = output[:, :, :h2, :w2, :]
|
output = output[:, :, :h2, :w2, :]
|
||||||
@ -428,28 +349,6 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
|||||||
def _to_seedvr2_raw(images):
|
def _to_seedvr2_raw(images):
|
||||||
return images.mul(2.0).sub(1.0)
|
return images.mul(2.0).sub(1.0)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _validate_upscaled_shorter_edge(upscaled_shorter_edge):
|
|
||||||
if not isinstance(upscaled_shorter_edge, int) or upscaled_shorter_edge < 2:
|
|
||||||
raise ValueError(
|
|
||||||
"SeedVR2PostProcessing: upscaled_shorter_edge must be an integer "
|
|
||||||
f"of at least 2 pixels; got {upscaled_shorter_edge!r}."
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _resized_shorter_edge_dims(height, width, upscaled_shorter_edge):
|
|
||||||
if height <= width:
|
|
||||||
return upscaled_shorter_edge, int(upscaled_shorter_edge * width / height)
|
|
||||||
return int(upscaled_shorter_edge * height / width), upscaled_shorter_edge
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _resize_original_reference(cls, original, upscaled_shorter_edge):
|
|
||||||
original_5d, _ = cls._as_bthwc(original)
|
|
||||||
b, t = original_5d.shape[:2]
|
|
||||||
original_flat = rearrange(original_5d, "b t h w c -> (b t) c h w")
|
|
||||||
resized_flat = side_resize(original_flat, upscaled_shorter_edge).clamp(0.0, 1.0)
|
|
||||||
return rearrange(resized_flat, "(b t) c h w -> b t h w c", b=b, t=t)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
|
def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
|
||||||
color_device = comfy.model_management.vae_device()
|
color_device = comfy.model_management.vae_device()
|
||||||
@ -575,7 +474,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
||||||
io.Latent.Input("vae_conditioning", display_name="LATENT", tooltip="The VAE-encoded latent to condition on."),
|
io.Latent.Input("vae_conditioning", display_name="latent", tooltip="The VAE-encoded latent to condition on."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(display_name = "model"),
|
io.Model.Output(display_name = "model"),
|
||||||
@ -606,7 +505,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
pos_cond = model.positive_conditioning
|
pos_cond = model.positive_conditioning
|
||||||
neg_cond = model.negative_conditioning
|
neg_cond = model.negative_conditioning
|
||||||
|
|
||||||
# Fail-loud guard against silently-wrong output when a numz-format
|
# Fail-loud guard against silently-wrong output when a
|
||||||
# DiT-only ``.safetensors`` (no ``positive_conditioning`` /
|
# DiT-only ``.safetensors`` (no ``positive_conditioning`` /
|
||||||
# ``negative_conditioning`` keys) is loaded via ``UNETLoader``.
|
# ``negative_conditioning`` keys) is loaded via ``UNETLoader``.
|
||||||
# ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see
|
# ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see
|
||||||
@ -623,7 +522,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
|
||||||
f"and negative_conditioning buffers are zero-valued — model "
|
f"and negative_conditioning buffers are zero-valued — model "
|
||||||
f"file appears to be a numz-format DiT-only export missing "
|
f"file appears to be a DiT-only export missing "
|
||||||
f"the SeedVR2 conditioning tensors. "
|
f"the SeedVR2 conditioning tensors. "
|
||||||
f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
|
f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
|
||||||
f"and ``negative_conditioning`` (64, 5120) keys at top level, "
|
f"and ``negative_conditioning`` (64, 5120) keys at top level, "
|
||||||
@ -716,8 +615,8 @@ def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
|
|||||||
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
|
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
|
||||||
"""1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``):
|
"""1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``):
|
||||||
Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3``
|
Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3``
|
||||||
(dead-band would collapse a tiny transition). Window shape matched to numz ``blend_overlapping_frames``
|
(dead-band would collapse a tiny transition). Window shape matched to the reference
|
||||||
for parity (reference, not source); caller broadcasts across ``(B, C, T_overlap, H, W)``.
|
overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``.
|
||||||
"""
|
"""
|
||||||
if overlap < 1:
|
if overlap < 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -831,22 +730,22 @@ def _concat_chunks_with_overlap_blend(chunk_specs, channels: int,
|
|||||||
|
|
||||||
def _run_standard_sample(model, seed: int, steps: int, cfg: float,
|
def _run_standard_sample(model, seed: int, steps: int, cfg: float,
|
||||||
sampler_name: str, scheduler: str,
|
sampler_name: str, scheduler: str,
|
||||||
positive, negative, latent_image: dict,
|
positive, negative, latent: dict,
|
||||||
denoise: float) -> dict:
|
denoise: float) -> dict:
|
||||||
"""Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk."""
|
"""Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk."""
|
||||||
samples_in = latent_image["samples"]
|
samples_in = latent["samples"]
|
||||||
samples_in = comfy.sample.fix_empty_latent_channels(
|
samples_in = comfy.sample.fix_empty_latent_channels(
|
||||||
model, samples_in, latent_image.get("downscale_ratio_spacial", None),
|
model, samples_in, latent.get("downscale_ratio_spacial", None),
|
||||||
)
|
)
|
||||||
batch_inds = latent_image.get("batch_index", None)
|
batch_inds = latent.get("batch_index", None)
|
||||||
noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds)
|
noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds)
|
||||||
noise_mask = latent_image.get("noise_mask", None)
|
noise_mask = latent.get("noise_mask", None)
|
||||||
samples = comfy.sample.sample(
|
samples = comfy.sample.sample(
|
||||||
model, noise, steps, cfg, sampler_name, scheduler,
|
model, noise, steps, cfg, sampler_name, scheduler,
|
||||||
positive, negative, samples_in,
|
positive, negative, samples_in,
|
||||||
denoise=denoise, noise_mask=noise_mask, seed=seed,
|
denoise=denoise, noise_mask=noise_mask, seed=seed,
|
||||||
)
|
)
|
||||||
out = latent_image.copy()
|
out = latent.copy()
|
||||||
out.pop("downscale_ratio_spacial", None)
|
out.pop("downscale_ratio_spacial", None)
|
||||||
out["samples"] = samples
|
out["samples"] = samples
|
||||||
return out
|
return out
|
||||||
@ -904,7 +803,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
|||||||
tooltip="The conditioning describing the attributes you want to include in the image."),
|
tooltip="The conditioning describing the attributes you want to include in the image."),
|
||||||
io.Conditioning.Input("negative",
|
io.Conditioning.Input("negative",
|
||||||
tooltip="The conditioning describing the attributes you want to exclude from the image."),
|
tooltip="The conditioning describing the attributes you want to exclude from the image."),
|
||||||
io.Latent.Input("latent_image",
|
io.Latent.Input("latent",
|
||||||
tooltip="The latent image to denoise."),
|
tooltip="The latent image to denoise."),
|
||||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0,
|
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0,
|
||||||
step=0.01,
|
step=0.01,
|
||||||
@ -920,12 +819,12 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
|||||||
default="manual",
|
default="manual",
|
||||||
tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."),
|
tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."),
|
||||||
],
|
],
|
||||||
outputs=[io.Latent.Output()],
|
outputs=[io.Latent.Output(display_name="latent")],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, seed, steps, cfg, sampler_name, scheduler,
|
def execute(cls, model, seed, steps, cfg, sampler_name, scheduler,
|
||||||
positive, negative, latent_image, denoise,
|
positive, negative, latent, denoise,
|
||||||
frames_per_chunk, temporal_overlap,
|
frames_per_chunk, temporal_overlap,
|
||||||
chunking_mode="manual") -> io.NodeOutput:
|
chunking_mode="manual") -> io.NodeOutput:
|
||||||
# 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline
|
# 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline
|
||||||
@ -941,10 +840,10 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
|||||||
f"got {frames_per_chunk}."
|
f"got {frames_per_chunk}."
|
||||||
)
|
)
|
||||||
|
|
||||||
samples_4d = latent_image["samples"]
|
samples_4d = latent["samples"]
|
||||||
samples_4d = comfy.sample.fix_empty_latent_channels(
|
samples_4d = comfy.sample.fix_empty_latent_channels(
|
||||||
model, samples_4d,
|
model, samples_4d,
|
||||||
latent_image.get("downscale_ratio_spacial", None),
|
latent.get("downscale_ratio_spacial", None),
|
||||||
)
|
)
|
||||||
if samples_4d.ndim != 4:
|
if samples_4d.ndim != 4:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -979,7 +878,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
|||||||
model=model, seed=seed, steps=steps, cfg=cfg,
|
model=model, seed=seed, steps=steps, cfg=cfg,
|
||||||
sampler_name=sampler_name, scheduler=scheduler,
|
sampler_name=sampler_name, scheduler=scheduler,
|
||||||
positive=positive, negative=negative,
|
positive=positive, negative=negative,
|
||||||
latent_image=latent_image, denoise=denoise,
|
latent=latent, denoise=denoise,
|
||||||
frames_per_chunk=attempt_frames_per_chunk,
|
frames_per_chunk=attempt_frames_per_chunk,
|
||||||
temporal_overlap=temporal_overlap,
|
temporal_overlap=temporal_overlap,
|
||||||
chunking_mode="manual",
|
chunking_mode="manual",
|
||||||
@ -1006,12 +905,12 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
|||||||
# Short-circuit: total fits in one chunk -> standard path with no
|
# Short-circuit: total fits in one chunk -> standard path with no
|
||||||
# chunking overhead. Output of this branch is byte-identical to the
|
# chunking overhead. Output of this branch is byte-identical to the
|
||||||
# built-in KSampler given the same (model, seed, steps, cfg,
|
# built-in KSampler given the same (model, seed, steps, cfg,
|
||||||
# sampler_name, scheduler, positive, negative, latent_image,
|
# sampler_name, scheduler, positive, negative, latent,
|
||||||
# denoise) tuple.
|
# denoise) tuple.
|
||||||
if T_pixel <= frames_per_chunk:
|
if T_pixel <= frames_per_chunk:
|
||||||
return io.NodeOutput(_run_standard_sample(
|
return io.NodeOutput(_run_standard_sample(
|
||||||
model, seed, steps, cfg, sampler_name, scheduler,
|
model, seed, steps, cfg, sampler_name, scheduler,
|
||||||
positive, negative, latent_image, denoise,
|
positive, negative, latent, denoise,
|
||||||
))
|
))
|
||||||
|
|
||||||
# Map pixel chunk -> latent chunk. Each chunk's latent length is
|
# Map pixel chunk -> latent chunk. Each chunk's latent length is
|
||||||
@ -1038,10 +937,10 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
|||||||
# per chunk) preserves seed-determinism across chunk-count
|
# per chunk) preserves seed-determinism across chunk-count
|
||||||
# variations: the same (seed, total T_latent) always produces the
|
# variations: the same (seed, total T_latent) always produces the
|
||||||
# same noise samples regardless of how the work is partitioned.
|
# same noise samples regardless of how the work is partitioned.
|
||||||
batch_inds = latent_image.get("batch_index", None)
|
batch_inds = latent.get("batch_index", None)
|
||||||
noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds)
|
noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds)
|
||||||
|
|
||||||
noise_mask = latent_image.get("noise_mask", None)
|
noise_mask = latent.get("noise_mask", None)
|
||||||
|
|
||||||
# Build the flat list of chunk ranges first so the chunking
|
# Build the flat list of chunk ranges first so the chunking
|
||||||
# geometry is fully known before any sample call.
|
# geometry is fully known before any sample call.
|
||||||
@ -1096,7 +995,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
|||||||
chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap,
|
chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = latent_image.copy()
|
out = latent.copy()
|
||||||
out.pop("downscale_ratio_spacial", None)
|
out.pop("downscale_ratio_spacial", None)
|
||||||
out["samples"] = final
|
out["samples"] = final
|
||||||
return io.NodeOutput(out)
|
return io.NodeOutput(out)
|
||||||
@ -1107,8 +1006,7 @@ class SeedVRExtension(ComfyExtension):
|
|||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
SeedVR2Conditioning,
|
SeedVR2Conditioning,
|
||||||
SeedVR2Resize,
|
SeedVR2Preprocess,
|
||||||
SeedVR2ResizeAdvanced,
|
|
||||||
SeedVR2PostProcessing,
|
SeedVR2PostProcessing,
|
||||||
SeedVR2ProgressiveSampler,
|
SeedVR2ProgressiveSampler,
|
||||||
]
|
]
|
||||||
|
|||||||
@ -123,7 +123,7 @@ def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
|
|||||||
"model",
|
"model",
|
||||||
"vae_conditioning",
|
"vae_conditioning",
|
||||||
]
|
]
|
||||||
assert schema.inputs[1].display_name == "LATENT"
|
assert schema.inputs[1].display_name == "latent"
|
||||||
assert [output.display_name for output in schema.outputs] == [
|
assert [output.display_name for output in schema.outputs] == [
|
||||||
"model",
|
"model",
|
||||||
"positive",
|
"positive",
|
||||||
|
|||||||
@ -10,21 +10,6 @@ from comfy.cli_args import args as cli_args
|
|||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
cli_args.cpu = True
|
cli_args.cpu = True
|
||||||
|
|
||||||
import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
def test_resize_simple_multiplier_resolves_upscaled_shorter_edge():
|
|
||||||
images = torch.zeros(1, 3, 16, 20, 3)
|
|
||||||
|
|
||||||
output = nodes_seedvr.SeedVR2Resize.execute(images, 4.0)
|
|
||||||
|
|
||||||
input_pixels, original_image, upscaled_shorter_edge = output.result
|
|
||||||
assert tuple(input_pixels.shape) == (1, 5, 64, 80, 3)
|
|
||||||
assert input_pixels.min().item() == 0.0
|
|
||||||
assert input_pixels.max().item() == 0.0
|
|
||||||
assert original_image is images
|
|
||||||
assert upscaled_shorter_edge == 64
|
|
||||||
|
|
||||||
|
|
||||||
def test_seedvr_node_signature_matches_schema():
|
def test_seedvr_node_signature_matches_schema():
|
||||||
mock_mm = MagicMock()
|
mock_mm = MagicMock()
|
||||||
@ -46,7 +31,7 @@ def test_seedvr_node_signature_matches_schema():
|
|||||||
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
try:
|
try:
|
||||||
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
|
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
|
||||||
for node_cls in (nodes_seedvr.SeedVR2Resize, nodes_seedvr.SeedVR2ResizeAdvanced):
|
for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler):
|
||||||
schema_ids = [i.id for i in node_cls.define_schema().inputs]
|
schema_ids = [i.id for i in node_cls.define_schema().inputs]
|
||||||
exec_params = [
|
exec_params = [
|
||||||
p for p in inspect.signature(node_cls.execute).parameters.keys()
|
p for p in inspect.signature(node_cls.execute).parameters.keys()
|
||||||
|
|||||||
@ -17,12 +17,9 @@ def _schema_ids(items):
|
|||||||
def test_seedvr2_post_processing_schema():
|
def test_seedvr2_post_processing_schema():
|
||||||
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
|
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
|
||||||
|
|
||||||
assert _schema_ids(schema.inputs) == ["decoded", "original_image", "upscaled_shorter_edge", "color_correction_method"]
|
assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
|
||||||
assert schema.inputs[2].default is None
|
assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
|
||||||
assert schema.inputs[2].min == 2
|
assert schema.inputs[2].default == "lab"
|
||||||
assert schema.inputs[2].force_input is True
|
|
||||||
assert schema.inputs[3].options == ["lab", "wavelet", "adain", "none"]
|
|
||||||
assert schema.inputs[3].default == "lab"
|
|
||||||
assert schema.outputs[0].get_io_type() == "IMAGE"
|
assert schema.outputs[0].get_io_type() == "IMAGE"
|
||||||
|
|
||||||
|
|
||||||
@ -53,7 +50,7 @@ def test_seedvr2_post_processing_unknown_color_correction_method_raises():
|
|||||||
decoded = torch.zeros(1, 2, 4, 4, 3)
|
decoded = torch.zeros(1, 2, 4, 4, 3)
|
||||||
original = torch.zeros(1, 2, 4, 4, 3)
|
original = torch.zeros(1, 2, 4, 4, 3)
|
||||||
try:
|
try:
|
||||||
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus")
|
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
assert "color_correction_method" in str(exc)
|
assert "color_correction_method" in str(exc)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -83,7 +83,7 @@ def test_auto_chunking_walks_two_three_four_chunk_ladder():
|
|||||||
out = SeedVR2ProgressiveSampler.execute(
|
out = SeedVR2ProgressiveSampler.execute(
|
||||||
model=None, seed=0, steps=2, cfg=1.0,
|
model=None, seed=0, steps=2, cfg=1.0,
|
||||||
sampler_name="euler", scheduler="simple",
|
sampler_name="euler", scheduler="simple",
|
||||||
positive=pos, negative=neg, latent_image=latent,
|
positive=pos, negative=neg, latent=latent,
|
||||||
denoise=1.0, frames_per_chunk=65, temporal_overlap=0,
|
denoise=1.0, frames_per_chunk=65, temporal_overlap=0,
|
||||||
chunking_mode="auto",
|
chunking_mode="auto",
|
||||||
)
|
)
|
||||||
@ -119,7 +119,7 @@ def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
|
|||||||
SeedVR2ProgressiveSampler.execute(
|
SeedVR2ProgressiveSampler.execute(
|
||||||
model=None, seed=0, steps=2, cfg=1.0,
|
model=None, seed=0, steps=2, cfg=1.0,
|
||||||
sampler_name="euler", scheduler="simple",
|
sampler_name="euler", scheduler="simple",
|
||||||
positive=pos, negative=neg, latent_image=latent,
|
positive=pos, negative=neg, latent=latent,
|
||||||
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
|
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
|
||||||
)
|
)
|
||||||
assert str(bad_chunk) in str(excinfo.value)
|
assert str(bad_chunk) in str(excinfo.value)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user