diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py index 71b71d4ad..df91c7772 100644 --- a/comfy/ldm/seedvr/constants.py +++ b/comfy/ldm/seedvr/constants.py @@ -8,26 +8,13 @@ Provenance prefixes: ISO / CIE values; cite the standard. """ -# -------------------------------------------------------------------------------------- -# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) -# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) -# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 -# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). -# -------------------------------------------------------------------------------------- -SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) -SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB - -# -------------------------------------------------------------------------------------- -# B. Fork heuristics (SEEDVR2 - this integration) -# -------------------------------------------------------------------------------------- SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) -SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # OOM retry backoff: halve the chunk and retry. SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). -SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). # Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. @@ -36,7 +23,7 @@ SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. # -------------------------------------------------------------------------------------- -# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) +# ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) # -------------------------------------------------------------------------------------- BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. @@ -56,7 +43,7 @@ BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max freq BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). # -------------------------------------------------------------------------------------- -# D. Published standards (cite the literature) +# Published standards (cite the literature) # -------------------------------------------------------------------------------------- ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 1fb44ac36..bf5b3c15c 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -1,12 +1,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io import torch -import math -import logging import comfy.model_management -import comfy.sample -import comfy.samplers from comfy.ldm.seedvr.color_fix import ( adain_color_transfer, lab_color_transfer, @@ -14,10 +10,7 @@ from comfy.ldm.seedvr.color_fix import ( ) from comfy.ldm.seedvr.constants import ( SEEDVR2_ADAIN_SCALE_MULTIPLIER, - SEEDVR2_CHUNK_FRAMES_PER_GB, - SEEDVR2_CHUNK_GB_MARGIN, SEEDVR2_COLOR_MEM_HEADROOM, - SEEDVR2_COND_CHANNELS, SEEDVR2_DTYPE_BYTES_FLOOR, SEEDVR2_LAB_SCALE_MULTIPLIER, SEEDVR2_LATENT_CHANNELS, @@ -39,40 +32,6 @@ _SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( _ATTR_MISSING = object() -def _seedvr2_vram_seed_frames_per_chunk(free_bytes, t_pixel): - """Predict the largest 4n+1 pixel-frame chunk that fits in free_bytes.""" - free_gb = free_bytes / (1024 ** 3) - predicted = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_gb - SEEDVR2_CHUNK_GB_MARGIN) - # round (not floor) to 4n+1: the fit's central prediction lands on measured n_max - n = round((predicted - 1) / 4) - seed = 4 * int(n) + 1 - seed = max(1, min(seed, t_pixel)) - return seed - - -def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): - """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" - attempts = [frames_per_chunk] - current_chunk_latent = ( - t_latent if t_pixel <= frames_per_chunk - else (frames_per_chunk - 1) // 4 + 1 - ) - current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) - seen = {frames_per_chunk} - - for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): - chunk_latent = max(1, math.ceil(t_latent / target_chunks)) - candidate = 4 * (chunk_latent - 1) + 1 - if candidate in seen: - continue - if candidate >= attempts[-1]: - continue - attempts.append(candidate) - seen.add(candidate) - - return attempts - - def _resolve_seedvr2_diffusion_model(model): """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" inner = getattr(model, "model", _ATTR_MISSING) @@ -473,478 +432,6 @@ class SeedVR2Conditioning(io.ComfyNode): return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) -def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, - t_end: int, channels: int) -> torch.Tensor: - """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" - B, CT, H, W = tensor_4d.shape - if CT % channels != 0: - raise ValueError( - f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " - f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." - ) - T = CT // channels - if not (0 <= t_start < t_end <= T): - raise ValueError( - f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " - f"range for T={T}." - ) - new_T = t_end - t_start - sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() - return sliced.reshape(B, channels * new_T, H, W) - - -def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): - """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" - new_list = [] - for entry in cond_list: - text_cond, options = entry[0], entry[1] - if "condition" not in options: - new_list.append(entry) - continue - new_options = options.copy() - new_options["condition"] = _slice_collapsed_4d_along_t( - new_options["condition"], t_start, t_end, - SEEDVR2_COND_CHANNELS, - ) - new_list.append([text_cond, new_options]) - return new_list - - -def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, - samples_4d: torch.Tensor, - t_start: int, - t_end: int): - """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" - if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: - return _slice_collapsed_4d_along_t( - noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, - ) - return noise_mask - - -def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: - """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" - if len(chunks_4d) == 0: - raise ValueError("_concat_chunks_along_t: empty chunk list.") - fives = [] - for ch in chunks_4d: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " - f"channel dim {CT} not divisible by channels={channels}." - ) - T = CT // channels - fives.append(ch.reshape(B, channels, T, H, W)) - cat = torch.cat(fives, dim=2).contiguous() - B, C, T_total, H, W = cat.shape - return cat.reshape(B, C * T_total, H, W) - - -def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: - """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): - Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` - (dead-band would collapse a tiny transition). Window shape matched to the reference - overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. - """ - if overlap < 1: - raise ValueError( - f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." - ) - if overlap >= 3: - t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) - blend_start = 1.0 / 3.0 - blend_end = 2.0 / 3.0 - u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) - return 0.5 + 0.5 * torch.cos(torch.pi * u) - return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) - - -def _blend_overlap_region(prev_tail_5d: torch.Tensor, - cur_head_5d: torch.Tensor) -> torch.Tensor: - """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" - if prev_tail_5d.shape != cur_head_5d.shape: - raise ValueError( - f"_blend_overlap_region: shape mismatch " - f"prev {tuple(prev_tail_5d.shape)} vs " - f"cur {tuple(cur_head_5d.shape)}." - ) - overlap = int(prev_tail_5d.shape[2]) - w_prev_1d = _hann_blend_weights_1d( - overlap, prev_tail_5d.device, prev_tail_5d.dtype, - ) - # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. - w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) - w_cur = 1.0 - w_prev - return prev_tail_5d * w_prev + cur_head_5d * w_cur - - -def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, - overlap_latent: int) -> torch.Tensor: - """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" - if len(chunk_specs) == 0: - raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") - if overlap_latent < 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: overlap_latent must be " - f">= 0; got {overlap_latent}." - ) - - # Validate channel divisibility once and capture per-chunk T. - chunk_5d = [] - for t_start, t_end, ch in chunk_specs: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk shape " - f"{tuple(ch.shape)} channel dim {CT} not divisible " - f"by channels={channels}." - ) - T = CT // channels - if t_end - t_start != T: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " - f"declared range [{t_start}:{t_end}]." - ) - chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) - - if overlap_latent == 0: - # Fast path: pure concat in the caller-provided chunk order. - return _concat_chunks_along_t( - [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) - for _, _, c in chunk_5d], - channels, - ) - - T_total = max(t_end for _, t_end, _ in chunk_5d) - first_5d = chunk_5d[0][2] - B = first_5d.shape[0] - H = first_5d.shape[3] - W = first_5d.shape[4] - result = torch.empty( - (B, channels, T_total, H, W), - device=first_5d.device, dtype=first_5d.dtype, - ) - filled_until = 0 - for i, (cs, ce, ct_5d) in enumerate(chunk_5d): - chunk_T = int(ct_5d.shape[2]) - if i == 0: - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - continue - # Overlap region width is bounded by both the previous fill - # frontier and the current chunk's actual length (for runt - # final chunks shorter than the configured overlap). - overlap_len = min(filled_until - cs, chunk_T) - if overlap_len > 0: - prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() - cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() - blended = _blend_overlap_region(prev_tail, cur_head) - result[:, :, cs:cs + overlap_len, :, :] = blended - tail_start = cs + overlap_len - tail_end = ce - if tail_end > tail_start: - result[:, :, tail_start:tail_end, :, :] = ( - ct_5d[:, :, overlap_len:, :, :] - ) - else: - # Disjoint chunks (overlap_latent set but this pair did not - # actually overlap, e.g. step_latent equal to chunk_latent - # in a degenerate config). Treat as concat. - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - - return result.contiguous().reshape(B, channels * T_total, H, W) - - -def _run_standard_sample(model, seed: int, steps: int, cfg: float, - sampler_name: str, scheduler: str, - positive, negative, latent: dict, - denoise: float) -> dict: - """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" - samples_in = latent["samples"] - samples_in = comfy.sample.fix_empty_latent_channels( - model, samples_in, latent.get("downscale_ratio_spacial", None), - ) - batch_inds = latent.get("batch_index", None) - noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) - noise_mask = latent.get("noise_mask", None) - samples = comfy.sample.sample( - model, noise, steps, cfg, sampler_name, scheduler, - positive, negative, samples_in, - denoise=denoise, noise_mask=noise_mask, seed=seed, - ) - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = samples - return out - - -class SeedVR2ProgressiveSampler(io.ComfyNode): - """Sequential temporal chunking sampler for SeedVR2 native. - - Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that - OOM on long sequences. The latent enters the sampler in SeedVR2's - collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` - at ``reshape(b, c * t, h, w)``); this node slices that - tensor along the temporal axis, runs the configured inner sampler - sequentially per chunk against the standard ``comfy.sample.sample`` - entry point, and concatenates per-chunk outputs back into a single - ``(B, 16*T_total, H, W)`` latent. - - ``frames_per_chunk`` is expressed in pixel-frame units to match the - SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the - VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` - maps to ``(F - 1) // 4 + 1`` latent-frame chunks. - - Determinism contract: a single noise tensor is generated once from - the user seed and sliced per chunk (rather than re-seeding each - chunk), so a workflow that fits in a single chunk produces output - identical to a workflow that fits in N chunks at the same seed, - modulo the inherent T-axis chunk-boundary independence of the model. - """ - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2ProgressiveSampler", - display_name="Sample SeedVR2 (Progressive)", - category="sampling", - description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", - search_aliases=["seedvr2", "upscale", "video upscale", "sampler", "chunk"], - inputs=[ - io.Model.Input("model", tooltip="The model used for denoising the input latent."), - io.Int.Input("seed", default=0, min=0, - max=0xffffffffffffffff, - control_after_generate=True, - tooltip="The random seed used for creating the noise."), - io.Int.Input("steps", default=20, min=1, max=10000, - tooltip="The number of steps used in the denoising process."), - io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, - step=0.1, round=0.01, - tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), - io.Combo.Input("sampler_name", - options=comfy.samplers.SAMPLER_NAMES, - tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), - io.Combo.Input("scheduler", - options=comfy.samplers.SCHEDULER_NAMES, - tooltip="The scheduler controls how noise is gradually removed to form the image."), - io.Conditioning.Input("positive", - tooltip="The conditioning describing the attributes you want to include in the image."), - io.Conditioning.Input("negative", - tooltip="The conditioning describing the attributes you want to exclude from the image."), - io.Latent.Input("latent", - tooltip="The latent image to denoise."), - io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, - step=0.01, - tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), - io.Int.Input("frames_per_chunk", default=21, min=1, - max=16384, step=4, - tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), - io.Int.Input("temporal_overlap", default=0, min=0, - max=16384, - tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), - io.Combo.Input("chunking_mode", - options=["manual", "auto"], - default="manual", - tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), - ], - outputs=[io.Latent.Output(display_name="latent", tooltip="The upscaled latent.")], - ) - - @classmethod - def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - frames_per_chunk, temporal_overlap, - chunking_mode="manual") -> io.NodeOutput: - # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline - # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), - # imposed at ``cut_videos`` upstream and propagated through the VAE's - # temporal_downsample_factor=4. Reject violations explicitly before - # any model invocation; a silent rounding would mis-align chunk - # boundaries with the 4n+1 lattice. - if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " - f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " - f"got {frames_per_chunk}." - ) - - samples_4d = latent["samples"] - if torch.count_nonzero(samples_4d) == 0: - raise ValueError( - "SeedVR2ProgressiveSampler: input latent is empty (all zeros). " - "SeedVR2 is an upscaler; connect an encoded latent from " - "'Apply SeedVR2 conditioning' rather than an empty latent." - ) - samples_4d = comfy.sample.fix_empty_latent_channels( - model, samples_4d, - latent.get("downscale_ratio_spacial", None), - ) - if samples_4d.ndim != 4: - raise ValueError( - f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " - f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." - ) - B, CT, H, W = samples_4d.shape - if CT % SEEDVR2_LATENT_CHANNELS != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " - f"not divisible by SeedVR2 latent channels " - f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " - f"SeedVR2-shaped." - ) - T_latent = CT // SEEDVR2_LATENT_CHANNELS - T_pixel = 4 * (T_latent - 1) + 1 - - if chunking_mode not in ("manual", "auto"): - raise ValueError( - f"SeedVR2ProgressiveSampler: chunking_mode must be " - f"'manual' or 'auto'; got {chunking_mode!r}." - ) - - if chunking_mode == "auto": - free_memory = comfy.model_management.get_free_memory(model.load_device) - seed_frames_per_chunk = _seedvr2_vram_seed_frames_per_chunk( - free_memory, T_pixel, - ) - logging.info( - "SeedVR2ProgressiveSampler auto: free=%.2fGB -> seeding " - "frames_per_chunk=%s (4n+1; T_pixel=%s).", - free_memory / (1024 ** 3), seed_frames_per_chunk, T_pixel, - ) - attempts = _seedvr2_auto_chunk_attempts( - T_latent, T_pixel, seed_frames_per_chunk, - ) - for i, attempt_frames_per_chunk in enumerate(attempts): - retry = False - try: - return cls.execute( - model=model, seed=seed, steps=steps, cfg=cfg, - sampler_name=sampler_name, scheduler=scheduler, - positive=positive, negative=negative, - latent=latent, denoise=denoise, - frames_per_chunk=attempt_frames_per_chunk, - temporal_overlap=temporal_overlap, - chunking_mode="manual", - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if i == len(attempts) - 1: - raise RuntimeError( - "SeedVR2ProgressiveSampler: exhausted auto " - "chunking attempts after OOM. Tried " - f"frames_per_chunk values {attempts}." - ) from e - retry = True - - if retry: - logging.warning( - "SeedVR2ProgressiveSampler auto chunking OOM at " - "frames_per_chunk=%s; retrying with " - "frames_per_chunk=%s.", - attempt_frames_per_chunk, attempts[i + 1], - ) - - # Short-circuit: total fits in one chunk -> standard path with no - # chunking overhead. Output of this branch is byte-identical to the - # built-in KSampler given the same (model, seed, steps, cfg, - # sampler_name, scheduler, positive, negative, latent, - # denoise) tuple. - if T_pixel <= frames_per_chunk: - return io.NodeOutput(_run_standard_sample( - model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - )) - - # Map pixel chunk -> latent chunk. Each chunk's latent length is - # at most ``chunk_latent``; the final chunk may be a runt that - # is automatically 4n+1-aligned in the pixel domain by the - # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer - # T_latent corresponds to a valid 4n+1 pixel count). - chunk_latent = (frames_per_chunk - 1) // 4 + 1 - - # ``temporal_overlap`` is exposed in latent-frame units, but users - # do not know the derived latent chunk length. Treat oversized - # values as "maximum valid overlap" while preserving a strictly - # positive chunk-loop stride. - if temporal_overlap < 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " - f"got {temporal_overlap}." - ) - temporal_overlap = min(temporal_overlap, chunk_latent - 1) - step_latent = chunk_latent - temporal_overlap - - # Generate full noise once from the user seed, then slice along T - # per chunk. Using one global noise tensor (rather than re-seeding - # per chunk) preserves seed-determinism across chunk-count - # variations: the same (seed, total T_latent) always produces the - # same noise samples regardless of how the work is partitioned. - batch_inds = latent.get("batch_index", None) - noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) - - noise_mask = latent.get("noise_mask", None) - - # Build the flat list of chunk ranges first so the chunking - # geometry is fully known before any sample call. - chunk_ranges = [] - for chunk_start in range(0, T_latent, step_latent): - chunk_end = min(chunk_start + chunk_latent, T_latent) - if chunk_start >= chunk_end: - # The final iteration of a stride that lands exactly on - # T_latent produces a zero-length chunk; skip it. - break - chunk_ranges.append((chunk_start, chunk_end)) - if chunk_end >= T_latent: - break - - def _sample_one_chunk(chunk_start, chunk_end): - samples_chunk = _slice_collapsed_4d_along_t( - samples_4d, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - noise_chunk = _slice_collapsed_4d_along_t( - noise_full, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - positive_chunk = _slice_seedvr2_cond_along_t( - positive, chunk_start, chunk_end, - ) - negative_chunk = _slice_seedvr2_cond_along_t( - negative, chunk_start, chunk_end, - ) - - # Per-chunk noise_mask handling: standard masks are passed - # through for KSampler expansion; pre-expanded collapsed - # masks are sliced. - chunk_noise_mask = None - if noise_mask is not None: - chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( - noise_mask, samples_4d, chunk_start, chunk_end, - ) - - return comfy.sample.sample( - model, noise_chunk, steps, cfg, sampler_name, scheduler, - positive_chunk, negative_chunk, samples_chunk, - denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, - ) - - chunk_specs = [] - for chunk_start, chunk_end in chunk_ranges: - chunk_samples = _sample_one_chunk(chunk_start, chunk_end) - chunk_specs.append((chunk_start, chunk_end, chunk_samples)) - - final = _concat_chunks_with_overlap_blend( - chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, - ) - - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = final - return io.NodeOutput(out) - - class SeedVRExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -952,7 +439,6 @@ class SeedVRExtension(ComfyExtension): SeedVR2Conditioning, SeedVR2Preprocess, SeedVR2PostProcessing, - SeedVR2ProgressiveSampler, ] async def comfy_entrypoint() -> SeedVRExtension: diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py index f7d9a4f65..1c5d20ac9 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py @@ -31,7 +31,7 @@ def test_seedvr_node_signature_matches_schema(): sys.modules.pop("comfy_extras.nodes_seedvr", None) try: nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") - for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): + for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning): schema_ids = [i.id for i in node_cls.define_schema().inputs] exec_params = [ p for p in inspect.signature(node_cls.execute).parameters.keys() diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py deleted file mode 100644 index 146b81225..000000000 --- a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sample # noqa: E402 -import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 -from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 - -_LAT_C = 16 -_COND_C = 17 - - -def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): - """Build minimal SeedVR2-shaped sampling inputs.""" - samples_5d = torch.arange( - B * _LAT_C * T * H * W, dtype=torch.float32 - ).reshape(B, _LAT_C, T, H, W) - samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() - - cond_5d = torch.arange( - B * _COND_C * T * H * W, dtype=torch.float32 - ).reshape(B, _COND_C, T, H, W) + 10000.0 - cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() - - text_pos = torch.zeros(1, 4, 32) - text_neg = torch.zeros(1, 4, 32) - positive = [[text_pos, {"condition": cond.clone()}]] - negative = [[text_neg, {"condition": cond.clone()}]] - latent_image = {"samples": samples} - return latent_image, positive, negative, samples_5d, cond_5d - - -def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): - return latent_image - - -def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): - """Return a tensor whose values encode ``(seed, position)``.""" - base = torch.arange( - latent_image.numel(), dtype=torch.float32 - ).reshape(latent_image.shape) - return base + float(seed) * 1e6 - - -def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): - schema = SeedVR2ProgressiveSampler.define_schema() - inputs = {item.id: item for item in schema.inputs} - - assert inputs["chunking_mode"].options == ["manual", "auto"] - assert inputs["chunking_mode"].default == "manual" - - -def test_vram_seed_frames_per_chunk_predicts_4n1_clamped_to_t_pixel(): - """VRAM chunk-size law: seed = nearest 4n+1 to 4*(free_GB - 3), clamped to [1, t_pixel].""" - gib = 1024 ** 3 - seed = nodes_seedvr_mod._seedvr2_vram_seed_frames_per_chunk - assert seed(20 * gib, 65) == 65 # 4*(20-3)=68 -> 4n+1 69 -> clamp to t_pixel 65 - assert seed(6 * gib, 97) == 13 # 4*(6-3)=12 -> nearest 4n+1 13 - assert seed(2 * gib, 97) == 1 # below margin -> floor at 1 - - -@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) -def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): - """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" - latent, pos, neg, _, _ = _make_inputs(T=5) - - sampler_called = {"n": 0} - - def _should_not_be_called(*args, **kwargs): - sampler_called["n"] += 1 - return torch.zeros(1) - - with patch.object(comfy.sample, "sample", - side_effect=_should_not_be_called), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise): - with pytest.raises(ValueError) as excinfo: - SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, - ) - assert str(bad_chunk) in str(excinfo.value) - assert sampler_called["n"] == 0