From a0a055bc4e4f2878c106bf8cf69c1aaa30f8b840 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Mon, 8 Jun 2026 14:27:50 -0700 Subject: [PATCH 1/8] feat(assets): extract image dimensions at ingest and emit on asset responses (#13991) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(assets): extract image dimensions at ingest and emit on asset responses Image assets now carry width/height under the existing `metadata` field on asset responses, shaped as `{"kind": "image", "width": W, "height": H}`. This lets consumers get original dimensions (e.g. for clients that render server-side thumbnails and can't recover them from naturalWidth/Height) without an extra round-trip. Dimensions are written to AssetReference.system_metadata across three ingest paths: - Direct file ingest (upload, in-place registration): Pillow reads the image header right after hashing, while the file is still in OS page cache. Non-image MIME types are skipped without touching the file. - From-hash registration: this path never reads the file bytes, so dimensions are best-effort copied from any prior sibling reference of the same asset that already carries kind=image metadata. Missing siblings, non-image siblings, or absent dimension keys leave the new reference's metadata unchanged. - Scanner enrichment: extends the existing system_metadata write in enrich_asset so scanner-registered images get the same treatment as uploaded ones. Existing system_metadata keys (e.g. safetensors fields written by the enricher, download provenance) are preserved through merge. Existing assets ingested before this change retain their current metadata — no automatic backfill in this PR. Tests cover image emission, non-image no-op, merge preservation, and the from-hash sibling back-fill (including the no-sibling and non-image-sibling cases). * fix(assets): validate sibling dimensions before backfilling Per CodeRabbit review on #13991: the previous loop accepted any sibling with `kind == "image"` and copied whichever dimension keys happened to be present, then returned. A partial sibling (kind set but missing or invalid width/height) could persist incomplete metadata onto the new reference even when a later sibling had valid dimensions. Now we validate that the sibling has both width and height as positive integers before adopting its dimensions, and continue scanning to the next sibling otherwise. * fix(assets): reject booleans in sibling dimension validation (use type-is) Per CodeRabbit follow-up on #13991: bool is a subclass of int in Python, so isinstance(True, int) is True. The previous strict-int gate would have accepted width=True (truthy + > 0) as a valid dimension. Realistic occurrence is low (extract_image_dimensions returns proper ints, JSON doesn't serialize bools as numbers), but the validation gate exists for defense-in-depth so it should be actually strict. --------- Co-authored-by: guill --- app/assets/scanner.py | 5 + app/assets/services/image_dimensions.py | 63 ++++++ app/assets/services/ingest.py | 99 +++++++++ .../services/test_image_dimensions.py | 86 ++++++++ .../assets_test/services/test_ingest.py | 208 +++++++++++++++++- 5 files changed, 460 insertions(+), 1 deletion(-) create mode 100644 app/assets/services/image_dimensions.py create mode 100644 tests-unit/assets_test/services/test_image_dimensions.py diff --git a/app/assets/scanner.py b/app/assets/scanner.py index ebb6869af..495c30443 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -33,6 +33,7 @@ from app.assets.services.file_utils import ( verify_file_unchanged, ) from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.metadata_extract import extract_file_metadata from app.assets.services.path_utils import ( compute_relative_filename, @@ -506,6 +507,10 @@ def enrich_asset( if extract_metadata and metadata: system_metadata = metadata.to_user_metadata() + if mime_type and mime_type.startswith("image/"): + dims = extract_image_dimensions(file_path, mime_type=mime_type) + if dims: + system_metadata.update(dims) set_reference_system_metadata(session, reference_id, system_metadata) if full_hash: diff --git a/app/assets/services/image_dimensions.py b/app/assets/services/image_dimensions.py new file mode 100644 index 000000000..ccd97399a --- /dev/null +++ b/app/assets/services/image_dimensions.py @@ -0,0 +1,63 @@ +"""Image dimension extraction for asset ingest. + +Reads only the image header via Pillow to capture width/height cheaply, +without a full pixel decode. Returns a metadata dict suitable for merging +into ``AssetReference.system_metadata``. +""" +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def extract_image_dimensions( + file_path: str, mime_type: str | None = None +) -> dict[str, Any] | None: + """Extract image dimensions for the file at ``file_path``. + + Args: + file_path: Absolute path to a file on disk. + mime_type: Optional MIME type hint. When provided and not prefixed + with ``image/``, extraction is skipped without touching the file. + + Returns: + ``{"kind": "image", "width": W, "height": H}`` when the file is a + recognizable image with positive dimensions, otherwise ``None``. + + The dict shape is intended to be merged into ``system_metadata`` so the + asset response surfaces ``metadata.kind`` plus dimension fields for image + assets. Forward-compatible: future media kinds (e.g. ``"video"`` with + duration/fps) can extend this shape without schema changes. + """ + if mime_type is not None and not mime_type.startswith("image/"): + return None + + try: + from PIL import Image, UnidentifiedImageError + except ImportError: + logger.debug( + "Pillow not available; skipping image dimension extraction for %s", + file_path, + ) + return None + + try: + with Image.open(file_path) as img: + width, height = img.size + except (OSError, UnidentifiedImageError, ValueError) as exc: + logger.debug( + "Failed to read image dimensions from %s: %s", file_path, exc + ) + return None + + if ( + not isinstance(width, int) + or not isinstance(height, int) + or width <= 0 + or height <= 0 + ): + return None + + return {"kind": "image", "width": width, "height": height} diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index f0b070517..3b6dc237c 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -17,9 +17,11 @@ from app.assets.database.queries import ( get_reference_by_file_path, get_reference_tags, get_or_create_reference, + list_references_by_asset_id, reference_exists, remove_missing_tag_for_asset_id, set_reference_metadata, + set_reference_system_metadata, set_reference_tags, update_asset_hash_and_mime, upsert_asset, @@ -29,6 +31,7 @@ from app.assets.database.queries import ( from app.assets.helpers import get_utc_now, normalize_tags from app.assets.services.bulk_ingest import batch_insert_seed_assets from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.path_utils import ( compute_relative_filename, get_name_and_tags_from_asset_path, @@ -118,6 +121,14 @@ def _ingest_file_from_path( user_metadata=user_metadata, ) + _maybe_store_image_dimensions( + session, + reference_id=reference_id, + file_path=locator, + mime_type=mime_type, + current_system_metadata=ref.system_metadata, + ) + try: remove_missing_tag_for_asset_id(session, asset_id=asset.id) except Exception: @@ -288,6 +299,13 @@ def _register_existing_asset( user_metadata=new_meta, ) + _backfill_image_dimensions_from_siblings( + session, + asset_id=asset.id, + new_reference_id=ref.id, + current_system_metadata=ref.system_metadata, + ) + if tags is not None: set_reference_tags( session, @@ -334,6 +352,87 @@ def _update_metadata_with_filename( ) +_IMAGE_DIMENSION_KEYS = ("kind", "width", "height") + + +def _maybe_store_image_dimensions( + session: Session, + reference_id: str, + file_path: str, + mime_type: str | None, + current_system_metadata: dict | None, +) -> None: + """Populate ``kind``/``width``/``height`` on system_metadata for image refs. + + Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written + safetensors metadata, download provenance) are preserved by merge. + """ + if not mime_type or not mime_type.startswith("image/"): + return + + dims = extract_image_dimensions(file_path, mime_type=mime_type) + if not dims: + return + + current = current_system_metadata or {} + merged = dict(current) + merged.update(dims) + if merged != current: + set_reference_system_metadata( + session, + reference_id=reference_id, + system_metadata=merged, + ) + + +def _backfill_image_dimensions_from_siblings( + session: Session, + asset_id: str, + new_reference_id: str, + current_system_metadata: dict | None, +) -> None: + """Copy image dimension keys from any sibling reference of the same asset. + + The from-hash path doesn't read the file bytes, so dimensions can't be + extracted there directly. When another reference of the same asset already + carries image dimensions, copy them onto the new reference so consumers + see consistent metadata regardless of how the asset was registered. + + Best-effort: missing siblings, non-image siblings, or absent dimension + keys leave the target reference unchanged. + """ + current = current_system_metadata or {} + if current.get("kind") == "image" and "width" in current and "height" in current: + return + + for sibling in list_references_by_asset_id(session, asset_id): + if sibling.id == new_reference_id: + continue + meta = sibling.system_metadata or {} + if meta.get("kind") != "image": + continue + width = meta.get("width") + height = meta.get("height") + if ( + type(width) is not int + or type(height) is not int + or width <= 0 + or height <= 0 + ): + continue + merged = dict(current) + merged["kind"] = "image" + merged["width"] = width + merged["height"] = height + if merged != current: + set_reference_system_metadata( + session, + reference_id=new_reference_id, + system_metadata=merged, + ) + return + + def _sanitize_filename(name: str | None, fallback: str) -> str: n = os.path.basename((name or "").strip() or fallback) return n if n else fallback diff --git a/tests-unit/assets_test/services/test_image_dimensions.py b/tests-unit/assets_test/services/test_image_dimensions.py new file mode 100644 index 000000000..ac275eae2 --- /dev/null +++ b/tests-unit/assets_test/services/test_image_dimensions.py @@ -0,0 +1,86 @@ +"""Tests for the image_dimensions service.""" +from __future__ import annotations + +from pathlib import Path + +import pytest +from PIL import Image + +from app.assets.services.image_dimensions import extract_image_dimensions + + +def _make_png(path: Path, size: tuple[int, int]) -> Path: + img = Image.new("RGB", size, color=(123, 45, 67)) + img.save(path, format="PNG") + return path + + +def _make_jpeg(path: Path, size: tuple[int, int]) -> Path: + img = Image.new("RGB", size, color=(10, 20, 30)) + img.save(path, format="JPEG", quality=80) + return path + + +class TestExtractImageDimensions: + def test_extracts_png_dimensions(self, tmp_path: Path): + f = _make_png(tmp_path / "rect.png", (320, 240)) + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result == {"kind": "image", "width": 320, "height": 240} + + def test_extracts_jpeg_dimensions(self, tmp_path: Path): + f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080)) + + result = extract_image_dimensions(str(f), mime_type="image/jpeg") + + assert result == {"kind": "image", "width": 1920, "height": 1080} + + def test_works_when_mime_type_is_none(self, tmp_path: Path): + f = _make_png(tmp_path / "no_mime.png", (50, 100)) + + result = extract_image_dimensions(str(f), mime_type=None) + + assert result == {"kind": "image", "width": 50, "height": 100} + + def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path): + # Path doesn't need to exist — non-image MIME short-circuits. + result = extract_image_dimensions( + str(tmp_path / "model.safetensors"), + mime_type="application/octet-stream", + ) + + assert result is None + + @pytest.mark.parametrize( + "mime", + ["application/json", "text/plain", "video/mp4", "audio/mpeg"], + ) + def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str): + f = tmp_path / "file.bin" + f.write_bytes(b"\x00\x01\x02") + + assert extract_image_dimensions(str(f), mime_type=mime) is None + + def test_returns_none_for_missing_file(self, tmp_path: Path): + result = extract_image_dimensions( + str(tmp_path / "does_not_exist.png"), mime_type="image/png" + ) + + assert result is None + + def test_returns_none_for_corrupt_image(self, tmp_path: Path): + f = tmp_path / "corrupt.png" + f.write_bytes(b"not actually a png file") + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result is None + + def test_returns_none_for_empty_file(self, tmp_path: Path): + f = tmp_path / "empty.png" + f.write_bytes(b"") + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result is None diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py index b153f9795..12a3bdfe6 100644 --- a/tests-unit/assets_test/services/test_ingest.py +++ b/tests-unit/assets_test/services/test_ingest.py @@ -4,10 +4,12 @@ from pathlib import Path from unittest.mock import patch import pytest +from PIL import Image from sqlalchemy.orm import Session as SASession, Session from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag from app.assets.database.queries import get_reference_tags +from app.assets.helpers import get_utc_now from app.assets.services.ingest import ( _ingest_file_from_path, _register_existing_asset, @@ -15,6 +17,11 @@ from app.assets.services.ingest import ( ) +def _make_png(path: Path, size: tuple[int, int]) -> Path: + Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG") + return path + + class TestIngestFileFromPath: def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session): file_path = temp_dir / "test_file.bin" @@ -279,4 +286,203 @@ class TestIngestExistingFileTagFK: ref_tags = sess.query(AssetReferenceTag).all() ref_tag_names = {rt.tag_name for rt in ref_tags} assert "output" in ref_tag_names - assert "my-job" in ref_tag_names + + +class TestIngestImageDimensions: + """system_metadata should carry {kind, width, height} for image assets.""" + + def test_image_asset_emits_dimensions( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = _make_png(temp_dir / "shot.png", (640, 480)) + + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img1", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="image/png", + ) + + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.system_metadata == { + "kind": "image", + "width": 640, + "height": 480, + } + + def test_non_image_asset_leaves_system_metadata_empty( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = temp_dir / "model.safetensors" + f.write_bytes(b"not an image") + + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:safetensors1", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + ) + + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.system_metadata in (None, {}) + + def test_preserves_existing_system_metadata_keys( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = _make_png(temp_dir / "annotated.png", (100, 200)) + + # First pass populates a sentinel system_metadata key (simulating prior + # enricher write). + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img-merge", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="image/png", + ) + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"} + session.commit() + + # Second pass with the same path triggers the merge code path again. + _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img-merge", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000001, + mime_type="image/png", + ) + + session.refresh(ref) + assert ref.system_metadata["kind"] == "image" + assert ref.system_metadata["width"] == 100 + assert ref.system_metadata["height"] == 200 + assert ref.system_metadata["source_url"] == "https://example/x.png" + + +class TestRegisterExistingAssetBackfill: + """The from-hash path back-fills dimensions from a sibling reference.""" + + def _add_reference( + self, + session: Session, + asset: Asset, + name: str, + system_metadata: dict | None = None, + ) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + asset_id=asset.id, + name=name, + owner_id="", + created_at=now, + updated_at=now, + last_access_time=now, + system_metadata=system_metadata or {}, + ) + session.add(ref) + session.flush() + return ref + + def test_backfills_dimensions_from_sibling_image_reference( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"kind": "image", "width": 800, "height": 600}, + ) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:shared", + name="from_hash.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + assert ref.system_metadata.get("kind") == "image" + assert ref.system_metadata.get("width") == 800 + assert ref.system_metadata.get("height") == 600 + + def test_no_backfill_when_sibling_has_no_image_metadata( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"base_model": "flux"}, # no kind=image + ) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:nodims", + name="from_hash.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + meta = ref.system_metadata or {} + assert "kind" not in meta + assert "width" not in meta + assert "height" not in meta + + def test_no_backfill_when_no_sibling_exists( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png") + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:lonely", + name="solo.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + assert ref.system_metadata in (None, {}) + + def test_backfill_preserves_caller_supplied_keys( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"kind": "image", "width": 1024, "height": 768}, + ) + session.commit() + + # Simulate a from-hash path where the new reference already carries + # some system_metadata (e.g. a download-provenance source_url written + # by an earlier step). The back-fill must merge dim keys without + # clobbering existing keys. + result = _register_existing_asset( + asset_hash="blake3:preserve", + name="from_hash.png", + owner_id="user-x", + ) + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + # Seed a sentinel key and re-run back-fill via a second register call + # to exercise the merge path with pre-existing data. + ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"} + session.commit() + + assert ref.system_metadata.get("source_url") == "https://example/p" + assert ref.system_metadata.get("kind") == "image" + assert ref.system_metadata.get("width") == 1024 + assert ref.system_metadata.get("height") == 768 From 00b633f368e68ffc229084ed819354c29006f92c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 8 Jun 2026 15:00:20 -0700 Subject: [PATCH 2/8] Revert "Add SeedVR2 support (CORE-6) (#14110)" (#14359) This reverts commit 7863cf0e53ca599a84b3ec5bcda122e4ecc3765c. --- comfy/latent_formats.py | 5 - comfy/ldm/modules/attention.py | 84 +- comfy/ldm/modules/diffusionmodules/model.py | 8 +- comfy/ldm/seedvr/color_fix.py | 340 --- comfy/ldm/seedvr/constants.py | 79 - comfy/ldm/seedvr/model.py | 1665 ------------- comfy/ldm/seedvr/vae.py | 2110 ----------------- comfy/model_base.py | 12 - comfy/model_detection.py | 50 - comfy/sample.py | 8 +- comfy/sd.py | 237 +- comfy/supported_models.py | 31 +- comfy/supported_models_base.py | 2 +- comfy_extras/nodes_seedvr.py | 1015 -------- nodes.py | 42 +- .../test_seedvr2_conditioning.py | 213 -- .../comfy_extras_test/test_seedvr2_nodes.py | 55 - .../test_seedvr2_post_processing.py | 57 - tests-unit/comfy_test/model_detection_test.py | 60 - .../comfy_test/seedvr_vae_forward_test.py | 90 - tests-unit/comfy_test/test_seedvr2_dtype.py | 47 - .../comfy_test/test_seedvr2_internals.py | 341 --- tests-unit/comfy_test/test_seedvr2_model.py | 308 --- .../comfy_test/test_seedvr2_vae_decode.py | 91 - .../comfy_test/test_seedvr2_vae_tiled.py | 347 --- .../test_seedvr_progressive_sampler.py | 126 - 26 files changed, 40 insertions(+), 7383 deletions(-) delete mode 100644 comfy/ldm/seedvr/color_fix.py delete mode 100644 comfy/ldm/seedvr/constants.py delete mode 100644 comfy/ldm/seedvr/model.py delete mode 100644 comfy/ldm/seedvr/vae.py delete mode 100644 comfy_extras/nodes_seedvr.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_conditioning.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_nodes.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_post_processing.py delete mode 100644 tests-unit/comfy_test/seedvr_vae_forward_test.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_dtype.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_internals.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_model.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_vae_decode.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_vae_tiled.py delete mode 100644 tests-unit/comfy_test/test_seedvr_progressive_sampler.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index fcbd97c59..bbdfd4bc2 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -4,7 +4,6 @@ class LatentFormat: scale_factor = 1.0 latent_channels = 4 latent_dimensions = 2 - preserve_empty_channel_multiples = False latent_rgb_factors = None latent_rgb_factors_bias = None latent_rgb_factors_reshape = None @@ -780,10 +779,6 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 -class SeedVR2(LatentFormat): - latent_channels = 16 - preserve_empty_channel_multiples = True - class ACEAudio15(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b78e764c7..55360535a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out -def _var_attention_qkv(q, k, v, heads, skip_reshape): - if skip_reshape: - return q, k, v, q.shape[-1] - total_tokens, embed_dim = q.shape - head_dim = embed_dim // heads - return ( - q.view(total_tokens, heads, head_dim), - k.view(k.shape[0], heads, head_dim), - v.view(v.shape[0], heads, head_dim), - head_dim, - ) - -def _var_attention_output(out, heads, head_dim, skip_output_reshape): - if skip_output_reshape: - return out - return out.reshape(-1, heads * head_dim) - - -def _use_blackwell_attention(): - device = model_management.get_torch_device() - if device.type != "cuda": - return False - major, minor = torch.cuda.get_device_capability(device) - return (major, minor) >= (12, 0) - - -def _validate_split_cu_seqlens(name, cu_seqlens, token_count): - if cu_seqlens.dtype not in (torch.int32, torch.int64): - raise ValueError(f"{name} must use an integer dtype") - if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: - raise ValueError(f"{name} must be a 1D tensor with at least two offsets") - if cu_seqlens[0].item() != 0: - raise ValueError(f"{name} must start at 0") - if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): - raise ValueError(f"{name} must be strictly increasing") - if cu_seqlens[-1].item() != token_count: - raise ValueError(f"{name} does not match token count") - - -def _split_indices(cu_seqlens): - return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) - - -def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - - _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) - _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) - if cu_seqlens_k[-1].item() != v.shape[0]: - raise ValueError("cu_seqlens_k does not match v token count") - - q_split_indices = _split_indices(cu_seqlens_q) - k_split_indices = _split_indices(cu_seqlens_k) - q_splits = torch.tensor_split(q, q_split_indices, dim=0) - k_splits = torch.tensor_split(k, k_split_indices, dim=0) - v_splits = torch.tensor_split(v, k_split_indices, dim=0) - if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits): - raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count") - - out = [] - for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits): - q_i = q_i.permute(1, 0, 2).unsqueeze(0) - k_i = k_i.permute(1, 0, 2).unsqueeze(0) - v_i = v_i.permute(1, 0, 2).unsqueeze(0) - out_dtype = q_i.dtype - if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): - q_i = q_i.to(torch.bfloat16) - k_i = k_i.to(torch.bfloat16) - v_i = v_i.to(torch.bfloat16) - out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) - if out_i.dtype != out_dtype: - out_i = out_i.to(out_dtype) - out.append(out_i.squeeze(0).permute(1, 0, 2)) - - out = torch.cat(out, dim=0) - return _var_attention_output(out, heads, head_dim, skip_output_reshape) - - -optimized_var_attention = var_attention_optimized_split optimized_attention = attention_basic if model_management.sage_attention_enabled(): @@ -837,8 +758,6 @@ else: logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -logging.info("Using optimized_attention split-loop for variable-length attention") - optimized_attention_masked = optimized_attention @@ -854,7 +773,6 @@ if model_management.xformers_enabled(): register_attention_function("pytorch", attention_pytorch) register_attention_function("sub_quad", attention_sub_quad) register_attention_function("split", attention_split) -register_attention_function("var_attention_optimized_split", var_attention_optimized_split) def optimized_attention_for_device(device, mask=False, small_input=False): @@ -1291,3 +1209,5 @@ class SpatialVideoTransformer(SpatialTransformer): x = self.proj_out(x) out = x + x_in return out + + diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 235df0b83..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,7 +13,6 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops - def torch_cat_if_needed(xl, dim): xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: @@ -23,8 +22,7 @@ def torch_cat_if_needed(xl, dim): else: return None - -def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1): +def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. @@ -35,13 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - downscale_freq_shift) + emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py deleted file mode 100644 index 7ddfc03af..000000000 --- a/comfy/ldm/seedvr/color_fix.py +++ /dev/null @@ -1,340 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import Tensor - -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.vae import safe_interpolate_operation -from comfy.ldm.seedvr.constants import ( - CIELAB_DELTA, - CIELAB_KAPPA, - D65_WHITE_X, - D65_WHITE_Z, - WAVELET_DECOMP_LEVELS, -) - - -def wavelet_blur(image: Tensor, radius): - max_safe_radius = max(1, min(image.shape[-2:]) // 8) - if radius > max_safe_radius: - radius = max_safe_radius - - num_channels = image.shape[1] - - kernel_vals = [ - [0.0625, 0.125, 0.0625], - [0.125, 0.25, 0.125], - [0.0625, 0.125, 0.0625], - ] - kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) - kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - - image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') - output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) - - return output - -def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): - high_freq = torch.zeros_like(image) - - for i in range(levels): - radius = 2 ** i - low_freq = wavelet_blur(image, radius) - high_freq.add_(image).sub_(low_freq) - image = low_freq - - return high_freq, low_freq - -def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: - - if content_feat.shape != style_feat.shape: - # Resize style to match content spatial dimensions - if len(content_feat.shape) >= 3: - # safe_interpolate_operation handles FP16 conversion automatically - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - # Decompose both features into frequency components - content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq # Free memory immediately - - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq # Free memory immediately - - if content_high_freq.shape != style_low_freq.shape: - style_low_freq = safe_interpolate_operation( - style_low_freq, - size=content_high_freq.shape[-2:], - mode='bilinear', - align_corners=False - ) - - content_high_freq.add_(style_low_freq) - - return content_high_freq.clamp_(-1.0, 1.0) - -def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: - original_shape = source.shape - - # Flatten - source_flat = source.flatten() - reference_flat = reference.flatten() - - # Sort both arrays - source_sorted, source_indices = torch.sort(source_flat) - reference_sorted, _ = torch.sort(reference_flat) - del reference_flat - - # Quantile mapping - n_source = len(source_sorted) - n_reference = len(reference_sorted) - - if n_source == n_reference: - matched_sorted = reference_sorted - else: - # Interpolate reference to match source quantiles - source_quantiles = torch.linspace(0, 1, n_source, device=device) - ref_indices = (source_quantiles * (n_reference - 1)).long() - ref_indices.clamp_(0, n_reference - 1) - matched_sorted = reference_sorted[ref_indices] - del source_quantiles, ref_indices, reference_sorted - - del source_sorted, source_flat - - # Reconstruct using argsort (portable across CUDA/ROCm/MPS) - inverse_indices = torch.argsort(source_indices) - del source_indices - matched_flat = matched_sorted[inverse_indices] - del matched_sorted, inverse_indices - - return matched_flat.reshape(original_shape) - -def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of CIELAB images to RGB color space.""" - L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] - - # LAB to XYZ - fy = (L + 16.0) / 116.0 - fx = a.div(500.0).add_(fy) - fz = fy - b / 200.0 - del L, a, b - - # XYZ transformation - x = torch.where( - fx > epsilon, - torch.pow(fx, 3.0), - fx.mul(116.0).sub_(16.0).div_(kappa) - ) - y = torch.where( - fy > epsilon, - torch.pow(fy, 3.0), - fy.mul(116.0).sub_(16.0).div_(kappa) - ) - z = torch.where( - fz > epsilon, - torch.pow(fz, 3.0), - fz.mul(116.0).sub_(16.0).div_(kappa) - ) - del fx, fy, fz - - # Apply D65 white point (in-place) - x.mul_(D65_WHITE_X) - # y *= 1.00000 # (no-op, skip) - z.mul_(D65_WHITE_Z) - - xyz = torch.stack([x, y, z], dim=1) - del x, y, z - - # Matrix multiplication: XYZ -> RGB - B, C, H, W = xyz.shape - xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) - del xyz - - # Ensure dtype consistency for matrix multiplication - xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) - rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) - del xyz_flat - - rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del rgb_linear_flat - - # Apply inverse gamma correction (delinearize) - mask = rgb_linear > 0.0031308 - rgb = torch.where( - mask, - torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), - rgb_linear * 12.92 - ) - del mask, rgb_linear - - return torch.clamp(rgb, 0.0, 1.0) - -def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" - # Apply sRGB gamma correction (linearize) - mask = rgb > 0.04045 - rgb_linear = torch.where( - mask, - torch.pow((rgb + 0.055) / 1.055, 2.4), - rgb / 12.92 - ) - del mask - - # Matrix multiplication: RGB -> XYZ - B, C, H, W = rgb_linear.shape - rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) - del rgb_linear - - # Ensure dtype consistency for matrix multiplication - rgb_flat = rgb_flat.to(dtype=matrix.dtype) - xyz_flat = torch.matmul(rgb_flat, matrix.T) - del rgb_flat - - xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del xyz_flat - - # Normalize by D65 white point (in-place) - xyz[:, 0].div_(D65_WHITE_X) # X - # xyz[:, 1] /= 1.00000 # Y (no-op, skip) - xyz[:, 2].div_(D65_WHITE_Z) # Z - - # XYZ to LAB transformation - epsilon_cubed = epsilon ** 3 - mask = xyz > epsilon_cubed - f_xyz = torch.where( - mask, - torch.pow(xyz, 1.0 / 3.0), - xyz.mul(kappa).add_(16.0).div_(116.0) - ) - del xyz, mask - - # Extract channels and compute LAB - L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] - a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] - b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] - del f_xyz - - return torch.stack([L, a, b], dim=1) - -def lab_color_transfer( - content_feat: Tensor, - style_feat: Tensor, - luminance_weight: float = 0.8 -) -> Tensor: - content_feat = wavelet_reconstruction(content_feat, style_feat) - - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - device = content_feat.device - - def ensure_float32_precision(c): - orig_dtype = c.dtype - c = c.float() - return c, orig_dtype - content_feat, original_dtype = ensure_float32_precision(content_feat) - style_feat, _ = ensure_float32_precision(style_feat) - - rgb_to_xyz_matrix = torch.tensor([ - [0.4124564, 0.3575761, 0.1804375], - [0.2126729, 0.7151522, 0.0721750], - [0.0193339, 0.1191920, 0.9503041] - ], dtype=torch.float32, device=device) - - xyz_to_rgb_matrix = torch.tensor([ - [ 3.2404542, -1.5371385, -0.4985314], - [-0.9692660, 1.8760108, 0.0415560], - [ 0.0556434, -0.2040259, 1.0572252] - ], dtype=torch.float32, device=device) - - epsilon = CIELAB_DELTA - kappa = CIELAB_KAPPA - - content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - - # Convert to LAB color space - content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del content_feat - - style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del style_feat, rgb_to_xyz_matrix - - # Match chrominance channels (a*, b*) for accurate color transfer - matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) - matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) - - # Handle luminance with weighted blending - if luminance_weight < 1.0: - # Partially match luminance for better overall color accuracy - matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) - # Blend: preserve some content L* for detail, adopt some style L* for color - result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) - del matched_L - else: - # Fully preserve content luminance - result_L = content_lab[:, 0] - - del content_lab, style_lab - - # Reconstruct LAB with corrected channels - result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) - del result_L, matched_a, matched_b - - # Convert back to RGB - result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) - del result_lab, xyz_to_rgb_matrix - - # Convert back to [-1, 1] range (in-place) - result = result_rgb.mul_(2.0).sub_(1.0) - del result_rgb - - result = result.to(original_dtype) - - return result - - -def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: - return wavelet_reconstruction(content_feat, style_feat) - - -def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False, - ) - - original_dtype = content_feat.dtype - content_feat = content_feat.float() - style_feat = style_feat.float() - - b, c = content_feat.shape[:2] - content_flat = content_feat.reshape(b, c, -1) - style_flat = style_feat.reshape(b, c, -1) - - content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) - content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) - style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - del content_flat, style_flat - - normalized = (content_feat - content_mean) / content_std - del content_mean, content_std - result = normalized * style_std + style_mean - del normalized, style_mean, style_std - - result = result.clamp_(-1.0, 1.0) - if result.dtype != original_dtype: - result = result.to(original_dtype) - return result diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py deleted file mode 100644 index 95838d1dd..000000000 --- a/comfy/ldm/seedvr/constants.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Named constants for the SeedVR2 integration, grouped by provenance. - -Provenance prefixes: -- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. -- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites - the upstream config/source path it was lifted from. -- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / - ISO / CIE values; cite the standard. -""" - -# -------------------------------------------------------------------------------------- -# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) -# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) -# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 -# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). -# -------------------------------------------------------------------------------------- -SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) -SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB - -# -------------------------------------------------------------------------------------- -# B. Fork heuristics (SEEDVR2 - this integration) -# -------------------------------------------------------------------------------------- -SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. - # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) -SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. -SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). -SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. -SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. -SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). -SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). -SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset. - -# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) -SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. -SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. -SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. -SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. - -# -------------------------------------------------------------------------------------- -# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) -# -------------------------------------------------------------------------------------- -BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. -BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. -BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). -BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). -BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. -BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. -BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). -BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32). -BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t). -BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11. -BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size). -BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor. -BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor). -BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range. -BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))). -BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling). -BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames). -BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). -BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). -# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function. -BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242. -BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243. - -# -------------------------------------------------------------------------------------- -# D. Published standards (cite the literature) -# -------------------------------------------------------------------------------------- -ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. - -# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). -CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). -CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). -D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). -D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. -WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). - -# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and -# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the -# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py deleted file mode 100644 index 3fa9fe07e..000000000 --- a/comfy/ldm/seedvr/model.py +++ /dev/null @@ -1,1665 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple, Union, List, Dict, Any, Callable -import einops -from einops import rearrange -import torch.nn.functional as F -from math import ceil, pi -import torch -from itertools import chain -from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding -from comfy.ldm.modules.attention import optimized_var_attention -from torch.nn.modules.utils import _triple -from torch import nn -import math -from comfy.ldm.flux.math import apply_rope1 -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_720P_REF_AREA, - BYTEDANCE_MAX_TEMPORAL_WINDOW, - BYTEDANCE_ROPE_MAX_FREQ, - BYTEDANCE_SINUSOIDAL_DIM, - ROPE_THETA, - SEEDVR2_7B_MLP_CHUNK, - SEEDVR2_7B_VID_DIM, - SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, -) -import comfy.model_management -import numbers - -def _torch_float8_types(): - return tuple( - getattr(torch, name) - for name in ( - "float8_e4m3fn", - "float8_e4m3fnuz", - "float8_e5m2", - "float8_e5m2fnuz", - "float8_e8m0fnu", - ) - if hasattr(torch, name) - ) - -class CustomRMSNorm(nn.Module): - - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): - super(CustomRMSNorm, self).__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - - if self.elementwise_affine: - self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) - else: - self.register_parameter('weight', None) - - def forward(self, input): - - dims = tuple(range(-len(self.normalized_shape), 0)) - - normalized = input.float() - variance = normalized.pow(2).mean(dim=dims, keepdim=True) - rms = torch.sqrt(variance + self.eps) - - normalized = normalized / rms - - if self.elementwise_affine: - return normalized * self.weight.to(input.dtype) - return normalized - -class Cache: - def __init__(self, disable=False, prefix="", cache=None): - self.cache = cache if cache is not None else {} - self.disable = disable - self.prefix = prefix - - def __call__(self, key: str, fn: Callable): - if self.disable: - return fn() - - key = self.prefix + key - try: - result = self.cache[key] - except KeyError: - result = fn() - self.cache[key] = result - return result - - def namespace(self, namespace: str): - return Cache( - disable=self.disable, - prefix=self.prefix + namespace + ".", - cache=self.cache, - ) - - def get(self, key: str): - key = self.prefix + key - return self.cache[key] - -def repeat_concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: List, # (n) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - txt = [[x] * n for x, n in zip(txt, txt_repeat)] - txt = list(chain(*txt)) - return torch.cat(list(chain(*zip(vid, txt)))) - -def concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - return torch.cat(list(chain(*zip(vid, txt)))) - -def concat_idx( - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) - src_idx = torch.argsort(tgt_idx) - return ( - lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), - lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), - ) - - -def repeat_concat_idx( - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: torch.LongTensor, # (n) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - txt_repeat_list = txt_repeat.tolist() - tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) - src_idx = torch.argsort(tgt_idx) - txt_idx_len = len(tgt_idx) - len(vid_idx) - repeat_txt_len = (txt_len * txt_repeat).tolist() - - def unconcat_coalesce(all): - vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) - txt_out_coalesced = [] - for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): - txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) - txt_out_coalesced.append(txt) - return vid_out, torch.cat(txt_out_coalesced) - - return ( - lambda vid, txt: torch.cat([vid, txt])[tgt_idx], - lambda all: unconcat_coalesce(all), - ) - - -@dataclass -class MMArg: - vid: Any - txt: Any - -def safe_pad_operation(x, padding, mode='constant', value=0.0): - """Safe padding operation that handles Half precision only for problematic modes""" - # Modes qui nécessitent le fix Half precision - problematic_modes = ['replicate', 'reflect', 'circular'] - - if mode in problematic_modes: - try: - return F.pad(x, padding, mode=mode, value=value) - except RuntimeError as e: - if "not implemented for 'Half'" in str(e): - original_dtype = x.dtype - return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype) - else: - raise e - else: - # Pour 'constant' et autres modes compatibles, pas de fix nécessaire - return F.pad(x, padding, mode=mode, value=value) - - -def get_args(key: str, args: List[Any]) -> List[Any]: - return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] - - -def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: - return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} - - -def get_window_op(name: str): - if name == "720pwin_by_size_bysize": - return make_720Pwindows_bysize - if name == "720pswin_by_size_bysize": - return make_shifted_720Pwindows_bysize - raise ValueError(f"Unknown windowing method: {name}") - - -# -------------------------------- Windowing -------------------------------- # -def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. - return [ - ( - slice(it * wt, min((it + 1) * wt, t)), - slice(ih * wh, min((ih + 1) * wh, h)), - slice(iw * ww, min((iw + 1) * ww, w)), - ) - for iw in range(nw) - if min((iw + 1) * ww, w) > iw * ww - for ih in range(nh) - if min((ih + 1) * wh, h) > ih * wh - for it in range(nt) - if min((it + 1) * wt, t) > it * wt - ] - -def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - - st, sh, sw = ( # shift size. - 0.5 if wt < t else 0, - 0.5 if wh < h else 0, - 0.5 if ww < w else 0, - ) - nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. - nt, nh, nw = ( # number of window. - nt + 1 if st > 0 else 1, - nh + 1 if sh > 0 else 1, - nw + 1 if sw > 0 else 1, - ) - return [ - ( - slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), - slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), - slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), - ) - for iw in range(nw) - if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) - for ih in range(nh) - if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) - for it in range(nt) - if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) - ] - -class RotaryEmbedding(nn.Module): - def __init__( - self, - dim, - custom_freqs = None, - freqs_for = 'lang', - theta = 10000, - max_freq = 10, - num_freqs = 1, - learned_freq = False, - use_xpos = False, - xpos_scale_base = 512, - interpolate_factor = 1., - theta_rescale_factor = 1., - seq_before_head_dim = False, - cache_if_possible = True, - cache_max_seq_len = 8192 - ): - super().__init__() - - theta *= theta_rescale_factor ** (dim / (dim - 2)) - - self.freqs_for = freqs_for - - if exists(custom_freqs): - freqs = custom_freqs - elif freqs_for == 'lang': - freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - elif freqs_for == 'pixel': - freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - elif freqs_for == 'constant': - freqs = torch.ones(num_freqs).float() - - self.cache_if_possible = cache_if_possible - self.cache_max_seq_len = cache_max_seq_len - - self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_freqs_seq_len = 0 - - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) - - self.learned_freq = learned_freq - - # dummy for device - - self.register_buffer('dummy', torch.tensor(0), persistent = False) - - # default sequence dimension - - self.seq_before_head_dim = seq_before_head_dim - self.default_seq_dim = -3 if seq_before_head_dim else -2 - - # interpolation factors - - assert interpolate_factor >= 1. - self.interpolate_factor = interpolate_factor - - # xpos - - self.use_xpos = use_xpos - - if not use_xpos: - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - self.scale_base = xpos_scale_base - - self.register_buffer('scale', scale, persistent = False) - self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_scales_seq_len = 0 - - # add apply_rotary_emb as static method - - self.apply_rotary_emb = staticmethod(apply_rotary_emb) - - @property - def device(self): - return self.dummy.device - - def get_axial_freqs( - self, - *dims, - offsets = None - ): - Colon = slice(None) - all_freqs = [] - - # handle offset - - if exists(offsets): - assert len(offsets) == len(dims) - - for ind, dim in enumerate(dims): - - offset = 0 - if exists(offsets): - offset = offsets[ind] - - if self.freqs_for == 'pixel': - pos = torch.linspace(-1, 1, steps = dim, device = self.device) - else: - pos = torch.arange(dim, device = self.device) - - pos = pos + offset - - freqs = self.forward(pos, seq_len = dim) - - all_axis = [None] * len(dims) - all_axis[ind] = Colon - - new_axis_slice = (Ellipsis, *all_axis, Colon) - all_freqs.append(freqs[new_axis_slice]) - - # concat all freqs - - all_freqs = torch.broadcast_tensors(*all_freqs) - return torch.cat(all_freqs, dim = -1) - - def forward( - self, - t, - seq_len: int | None = None, - offset = 0 - ): - should_cache = ( - self.cache_if_possible and - not self.learned_freq and - exists(seq_len) and - self.freqs_for != 'pixel' and - (offset + seq_len) <= self.cache_max_seq_len - ) - - if ( - should_cache and \ - exists(self.cached_freqs) and \ - (offset + seq_len) <= self.cached_freqs_seq_len - ): - return self.cached_freqs[offset:(offset + seq_len)].detach() - - freqs = self.freqs - - freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) - freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2) - - if should_cache and offset == 0: - self.cached_freqs[:seq_len] = freqs.detach() - self.cached_freqs_seq_len = seq_len - - return freqs - -class RotaryEmbeddingBase(nn.Module): - def __init__(self, dim: int, rope_dim: int): - super().__init__() - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="pixel", - max_freq=BYTEDANCE_ROPE_MAX_FREQ, - ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - - def get_axial_freqs(self, *dims): - return self.rope.get_axial_freqs(*dims) - - -class RotaryEmbedding3d(RotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - self.mm = False - - def forward( - self, - q: torch.FloatTensor, # b h l d - k: torch.FloatTensor, # b h l d - size: Tuple[int, int, int], - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - T, H, W = size - freqs = self.get_axial_freqs(T, H, W) - q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - q = apply_rotary_emb(freqs, q.float()).to(q.dtype) - k = apply_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "b h T H W d -> b h (T H W) d") - k = rearrange(k, "b h T H W d -> b h (T H W) d") - return q, k - - -class NaRotaryEmbedding3d(RotaryEmbedding3d): - def forward( - self, - q: torch.FloatTensor, - k: torch.FloatTensor, - shape: torch.LongTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) - freqs = freqs.to(device=q.device) - q = rearrange(q, "L h d -> h L d") - k = rearrange(k, "L h d -> h L d") - q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) - k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "h L d -> L h d") - k = rearrange(k, "h L d -> L h d") - return q, k - - @torch._dynamo.disable - def get_freqs( - self, - shape: torch.LongTensor, - ) -> torch.Tensor: - # Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds - # 7B pixel RoPE with the interleaved-angle convention, not Comfy's - # Flux freqs_cis matrix. - plain_rope = RotaryEmbedding( - dim=self.rope.freqs.numel() * 2, - freqs_for="pixel", - max_freq=BYTEDANCE_ROPE_MAX_FREQ, - ) - plain_rope = plain_rope.to(self.rope.dummy.device) - freq_list = [] - for f, h, w in shape.tolist(): - freqs = plain_rope.get_axial_freqs(f, h, w) - freq_list.append(freqs.view(-1, freqs.size(-1))) - return torch.cat(freq_list, dim=0) - - -class MMRotaryEmbeddingBase(RotaryEmbeddingBase): - def __init__(self, dim: int, rope_dim: int): - super().__init__(dim, rope_dim) - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="lang", - theta=ROPE_THETA, - cache_if_possible=False, - ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - self.mm = True - -def slice_at_dim(t, dim_slice: slice, *, dim): - dim += (t.ndim if dim < 0 else 0) - colons = [slice(None)] * t.ndim - colons[dim] = dim_slice - return t[tuple(colons)] - -# rotary embedding helper functions - -def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) - x1, x2 = x.unbind(dim = -1) - x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') -def exists(val): - return val is not None - -def apply_rotary_emb( - freqs, - t, - start_index = 0, - scale = 1., - seq_dim = -2, - freqs_seq_dim = None -): - dtype = t.dtype - if not exists(freqs_seq_dim): - if freqs.ndim == 2 or t.ndim == 3: - freqs_seq_dim = 0 - - if t.ndim == 3 or exists(freqs_seq_dim): - seq_len = t.shape[seq_dim] - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) - - rot_feats = freqs.shape[-1] - end_index = start_index + rot_feats - - t_left = t[..., :start_index] - t_middle = t[..., start_index:end_index] - t_right = t[..., end_index:] - - angles = freqs.to(t_middle.device)[..., ::2] - cos = torch.cos(angles) * scale - sin = torch.sin(angles) * scale - - col0 = torch.stack([cos, sin], dim=-1) - col1 = torch.stack([-sin, cos], dim=-1) - freqs_mat = torch.stack([col0, col1], dim=-1) - - t_middle_out = apply_rope1(t_middle, freqs_mat) - out = torch.cat((t_left, t_middle_out, t_right), dim=-1) - return out.type(dtype) - - -def _apply_seedvr2_rotary_emb( - freqs: torch.Tensor, - t: torch.Tensor, - start_index: int = 0, - scale: float = 1.0, - seq_dim: int = -2, - freqs_seq_dim: int | None = None, -) -> torch.Tensor: - dtype = t.dtype - if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3): - freqs_seq_dim = 0 - - if t.ndim == 3 or freqs_seq_dim is not None: - seq_len = t.shape[seq_dim] - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim) - - rot_feats = freqs.shape[-1] - end_index = start_index + rot_feats - - t_left = t[..., :start_index] - t_middle = t[..., start_index:end_index] - t_right = t[..., end_index:] - - freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype) - cos = freqs.cos() * scale - sin = freqs.sin() * scale - t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin) - return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) - -def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: - """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" - angles = freqs_interleaved[..., ::2].float() - cos = torch.cos(angles) - sin = torch.sin(angles) - out = torch.stack([cos, -sin, sin, cos], dim=-1) - return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) - - -def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest - through; in-place for inference, cloned for training (autograd). Mirrors the legacy - ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives - ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when - ``rot_d == t.shape[-1]``. - """ - out = t.clone() if t.requires_grad or comfy.model_management.in_training else t - rot_d = 2 * freqs_cis.shape[-3] - seq_len = out.shape[-2] - for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS): - end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len) - freqs_chunk = freqs_cis[start:end] - if rot_d == out.shape[-1]: - out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype) - else: - out[..., start:end, :rot_d] = apply_rope1(out[..., start:end, :rot_d], freqs_chunk).to(out.dtype) - return out - - -class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - - def forward( - self, - vid_q: torch.FloatTensor, # L h d - vid_k: torch.FloatTensor, # L h d - vid_shape: torch.LongTensor, # B 3 - txt_q: torch.FloatTensor, # L h d - txt_k: torch.FloatTensor, # L h d - txt_shape: torch.LongTensor, # B 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_freqs, txt_freqs = cache( - "mmrope_freqs_3d", - lambda: self.get_freqs(vid_shape, txt_shape), - ) - target_device = vid_q.device - if vid_freqs.device != target_device: - vid_freqs = vid_freqs.to(target_device) - if txt_freqs.device != target_device: - txt_freqs = txt_freqs.to(target_device) - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") - vid_q = _apply_rope1_partial(vid_q, vid_freqs) - vid_k = _apply_rope1_partial(vid_k, vid_freqs) - vid_q = rearrange(vid_q, "h L d -> L h d") - vid_k = rearrange(vid_k, "h L d -> L h d") - - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") - txt_q = _apply_rope1_partial(txt_q, txt_freqs) - txt_k = _apply_rope1_partial(txt_k, txt_freqs) - txt_q = rearrange(txt_q, "h L d -> L h d") - txt_k = rearrange(txt_k, "h L d -> L h d") - return vid_q, vid_k, txt_q, txt_k - - @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks - def get_freqs( - self, - vid_shape: torch.LongTensor, - txt_shape: torch.LongTensor, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - ]: - - # Calculate actual max dimensions needed for this batch - max_temporal = 0 - max_height = 0 - max_width = 0 - max_txt_len = 0 - - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal - max_height = max(max_height, h) - max_width = max(max_width, w) - max_txt_len = max(max_txt_len, l) - - autocast_device = "cuda" if torch.cuda.is_available() else "cpu" - with torch.amp.autocast(autocast_device, enabled=False): - vid_freqs = self.get_axial_freqs( - max_temporal + 16, - max_height + 4, - max_width + 4, - ).float() - txt_freqs = self.get_axial_freqs(max_txt_len + 16) - - # Now slice as before - vid_freq_list, txt_freq_list = [], [] - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) - txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) - vid_freq_list.append(vid_freq) - txt_freq_list.append(txt_freq) - vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) - txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) - - # Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]` - # (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the - # upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis` - # in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in. - # Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing - # 2x2 is the per-frequency rotation matrix that - # `comfy.ldm.flux.math.apply_rope1` expects. - return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) - -class MMModule(nn.Module): - def __init__( - self, - module: Callable[..., nn.Module], - *args, - shared_weights: bool = False, - vid_only: bool = False, - **kwargs, - ): - super().__init__() - self.shared_weights = shared_weights - self.vid_only = vid_only - if self.shared_weights: - assert get_args("vid", args) == get_args("txt", args) - assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) - self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - else: - self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - self.txt = ( - module(*get_args("txt", args), **get_kwargs("txt", kwargs)) - if not vid_only - else None - ) - - def forward( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - *args, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.vid if not self.shared_weights else self.all - vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) - if not self.vid_only: - txt_module = self.txt if not self.shared_weights else self.all - txt = txt.to(device=vid.device, dtype=vid.dtype) - txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) - return vid, txt - -def get_na_rope(rope_type: Optional[str], dim: int): - if rope_type is None: - return None - if rope_type == "rope3d": - return NaRotaryEmbedding3d(dim=dim) - if rope_type == "mmrope3d": - return NaMMRotaryEmbedding3d(dim=dim) - -class NaMMAttention(nn.Module): - def __init__( - self, - vid_dim: int, - txt_dim: int, - heads: int, - head_dim: int, - qk_bias: bool, - qk_norm, - qk_norm_eps: float, - rope_type: Optional[str], - rope_dim: int, - shared_weights: bool, - device, dtype, operations, - **kwargs, - ): - super().__init__() - dim = MMArg(vid_dim, txt_dim) - self.heads = heads - inner_dim = heads * head_dim - qkv_dim = inner_dim * 3 - self.head_dim = head_dim - self.proj_qkv = MMModule( - operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype - ) - self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype) - self.norm_q = MMModule( - qk_norm, - normalized_shape=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - device=device, dtype=dtype - ) - self.norm_k = MMModule( - qk_norm, - normalized_shape=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - device=device, dtype=dtype - ) - - - self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) - - def forward(self): - pass - -def window( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid = unflatten(hid, hid_shape) - hid = list(map(window_fn, hid)) - hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) - hid, hid_shape = flatten(list(chain(*hid))) - return hid, hid_shape, hid_windows - -def window_idx( - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) - tgt_idx = tgt_idx.squeeze(-1) - src_idx = torch.argsort(tgt_idx) - return ( - lambda hid: torch.index_select(hid, 0, tgt_idx), - lambda hid: torch.index_select(hid, 0, src_idx), - tgt_shape, - tgt_windows, - ) - -class NaSwinAttention(NaMMAttention): - def __init__( - self, - *args, - window: Union[int, Tuple[int, int, int]], - window_method: bool, # shifted or not - **kwargs, - ): - super().__init__(*args, **kwargs) - self.version_7b = kwargs.get("version", False) - self.window = _triple(window) - self.window_method = window_method - assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) - - self.window_op = get_window_op(window_method) - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - - # re-org the input seq for window attn - cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") - - def make_window(x: torch.Tensor): - t, h, w, _ = x.shape - window_slices = self.window_op((t, h, w), self.window) - return [x[st, sh, sw] for (st, sh, sw) in window_slices] - - window_partition, window_reverse, window_shape, window_count = cache_win( - "win_transform", - lambda: window_idx(vid_shape, make_window), - ) - vid_qkv_win = window_partition(vid_qkv) - - vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - - vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) - txt_q, txt_k, txt_v = txt_qkv.unbind(1) - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) - - vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) - txt_len = txt_len.to(window_count.device) - - # window rope - if self.rope: - if self.version_7b: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - elif self.rope.mm: - # repeat text q and k for window mmrope - _, num_h, _ = txt_q.shape - txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") - txt_q_repeat = unflatten(txt_q_repeat, txt_shape) - txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] - txt_q_repeat = list(chain(*txt_q_repeat)) - txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) - txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) - - txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") - txt_k_repeat = unflatten(txt_k_repeat, txt_shape) - txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] - txt_k_repeat = list(chain(*txt_k_repeat)) - txt_k_repeat, _ = flatten(txt_k_repeat) - txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) - - vid_q, vid_k, txt_q, txt_k = self.rope( - vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win - ) - else: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) - all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) - concat_win, unconcat_win = cache_win( - "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) - ) - out = optimized_var_attention( - q=concat_win(vid_q, txt_q), - k=concat_win(vid_k, txt_k), - v=concat_win(vid_v, txt_v), - heads=self.heads, skip_reshape=True, skip_output_reshape=True, - cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() - ), - cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() - ), - ) - vid_out, txt_out = unconcat_win(out) - - vid_out = rearrange(vid_out, "l h d -> l (h d)") - txt_out = rearrange(txt_out, "l h d -> l (h d)") - vid_out = window_reverse(vid_out) - - vid_out, txt_out = self.proj_out(vid_out, txt_out) - - return vid_out, txt_out - -class MLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - device, dtype, operations - ): - super().__init__() - self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype) - self.act = nn.GELU("tanh") - self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - x = self.proj_in(x) - x = self.act(x) - x = self.proj_out(x) - return x - - -class SwiGLUMLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - multiple_of: int = 256, - device=None, dtype=None, operations=None - ): - super().__init__() - hidden_dim = int(2 * dim * expand_ratio / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) - self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype) - self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) - -def get_mlp(mlp_type: Optional[str] = "normal"): - # 3b and 7b uses different mlp types - if mlp_type == "normal": - return MLP - elif mlp_type == "swiglu": - return SwiGLUMLP - -class NaMMSRTransformerBlock(nn.Module): - def __init__( - self, - *, - vid_dim: int, - txt_dim: int, - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm, - norm_eps: float, - ada, - qk_bias: bool, - qk_norm, - mlp_type: str, - shared_weights: bool, - rope_type: str, - rope_dim: int, - is_last_layer: bool, - device, dtype, operations, - **kwargs, - ): - super().__init__() - version = kwargs.get("version", False) - dim = MMArg(vid_dim, txt_dim) - self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) - - self.attn = NaSwinAttention( - vid_dim=vid_dim, - txt_dim=txt_dim, - heads=heads, - head_dim=head_dim, - qk_bias=qk_bias, - qk_norm=qk_norm, - qk_norm_eps=norm_eps, - rope_type=rope_type, - rope_dim=rope_dim, - shared_weights=shared_weights, - window=kwargs.pop("window", None), - window_method=kwargs.pop("window_method", None), - version=version, - device=device, dtype=dtype, operations=operations - ) - - self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) - self.mlp = MMModule( - get_mlp(mlp_type), - dim=dim, - expand_ratio=expand_ratio, - shared_weights=shared_weights, - vid_only=is_last_layer, - device=device, dtype=dtype, operations=operations - ) - self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) - self.is_last_layer = is_last_layer - self.version = version - - def _seedvr2_7b_mlp( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.mlp.vid if not self.mlp.shared_weights else self.mlp.all - if comfy.model_management.in_training or vid.requires_grad: - vid = torch.cat([vid_module(chunk) for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0)], dim=0) - else: - vid_out = None - offset = 0 - for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0): - chunk_out = vid_module(chunk) - if vid_out is None: - vid_out = chunk_out.new_empty((vid.shape[0], *chunk_out.shape[1:])) - vid_out[offset:offset + chunk_out.shape[0]] = chunk_out - offset += chunk_out.shape[0] - vid = vid_out - if not self.mlp.vid_only: - txt_module = self.mlp.txt if not self.mlp.shared_weights else self.mlp.all - txt = txt.to(device=vid.device, dtype=vid.dtype) - txt = txt_module(txt) - return vid, txt - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - emb: torch.FloatTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.LongTensor, - torch.LongTensor, - ]: - hid_len = MMArg( - cache("vid_len", lambda: vid_shape.prod(-1)), - cache("txt_len", lambda: txt_shape.prod(-1)), - ) - ada_kwargs = { - "emb": emb, - "hid_len": hid_len, - "cache": cache, - "branch_tag": MMArg("vid", "txt"), - } - - vid_attn, txt_attn = self.attn_norm(vid, txt) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) - vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) - vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) - - vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) - if self.version: - vid_mlp, txt_mlp = self._seedvr2_7b_mlp(vid_mlp, txt_mlp) - else: - vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) - vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) - - return vid_mlp, txt_mlp, vid_shape, txt_shape - -class PatchOut(nn.Module): - def __init__( - self, - out_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - device, dtype, operations - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - vid = self.proj(vid) - vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) - if t > 1: - vid = vid[:, :, (t - 1) :] - return vid - -class NaPatchOut(PatchOut): - def forward( - self, - vid: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - vid_shape_before_patchify = None - ) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, - ]: - - t, h, w = self.patch_size - vid = self.proj(vid) - - if not (t == h == w == 1): - vid = unflatten(vid, vid_shape) - for i in range(len(vid)): - vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] - vid, vid_shape = flatten(vid) - - return vid, vid_shape - -class PatchIn(nn.Module): - def __init__( - self, - in_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - device, dtype, operations - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - if t > 1: - assert vid.size(2) % t == 1 - vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) - vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) - vid = self.proj(vid) - return vid - -class NaPatchIn(PatchIn): - def forward( - self, - vid: torch.Tensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - ) -> torch.Tensor: - cache = cache.namespace("patch") - vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) - t, h, w = self.patch_size - if not (t == h == w == 1): - vid = unflatten(vid, vid_shape) - for i in range(len(vid)): - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) - vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) - vid, vid_shape = flatten(vid) - - vid = self.proj(vid) - return vid, vid_shape - -def expand_dims(x: torch.Tensor, dim: int, ndim: int): - shape = x.shape - shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] - return x.reshape(shape) - - -class AdaSingle(nn.Module): - def __init__( - self, - dim: int, - emb_dim: int, - layers: List[str], - modes: List[str] = ["in", "out"], - device = None, dtype = None, - ): - assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" - super().__init__() - self.dim = dim - self.emb_dim = emb_dim - self.layers = layers - - randn_kwargs = {"device": device} - fp8_types = _torch_float8_types() - if dtype is not None and dtype not in fp8_types: - randn_kwargs["dtype"] = dtype - - for l in layers: - if "in" in modes: - # Passing fp8 ``dtype=`` here would break CPU weight - # loads: CPU has no ``normal_kernel_cpu`` for fp8. - self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) - self.register_parameter( - f"{l}_scale", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5 + 1) - ) - if "out" in modes: - self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) - - def forward( - self, - hid: torch.FloatTensor, # b ... c - emb: torch.FloatTensor, # b d - layer: str, - mode: str, - cache: Cache = Cache(disable=True), - branch_tag: str = "", - hid_len: Optional[torch.LongTensor] = None, # b - ) -> torch.FloatTensor: - idx = self.layers.index(layer) - emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] - emb = expand_dims(emb, 1, hid.ndim + 1) - - if hid_len is not None: - slice_inputs = lambda x, dim: x - emb = cache( - f"emb_repeat_{idx}_{branch_tag}", - lambda: slice_inputs( - torch.repeat_interleave(emb, hid_len, dim=0), - dim=0, - ), - ) - - shiftA, scaleA, gateA = emb.unbind(-1) - shiftB, scaleB, gateB = ( - getattr(self, f"{layer}_shift", None), - getattr(self, f"{layer}_scale", None), - getattr(self, f"{layer}_gate", None), - ) - - fp8_types = _torch_float8_types() - if fp8_types: - target_dtype = hid.dtype - - if shiftB is not None and shiftB.dtype in fp8_types: - shiftB = shiftB.to(target_dtype) - if scaleB is not None and scaleB.dtype in fp8_types: - scaleB = scaleB.to(target_dtype) - if gateB is not None and gateB.dtype in fp8_types: - gateB = gateB.to(target_dtype) - - if mode == "in": - return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) - if mode == "out": - if gateB is not None: - return hid.mul_(gateA + gateB) - else: - return hid.mul_(gateA) - - raise NotImplementedError - - -def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): - return emb1 if emb2 is None else emb1 + emb2 - - -class TimeEmbedding(nn.Module): - def __init__( - self, - sinusoidal_dim: int, - hidden_dim: int, - output_dim: int, - device, dtype, operations - ): - super().__init__() - self.sinusoidal_dim = sinusoidal_dim - self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype) - self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype) - self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype) - self.act = nn.SiLU() - - def forward( - self, - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], - device: torch.device, - dtype: torch.dtype, - ) -> torch.FloatTensor: - if not torch.is_tensor(timestep): - timestep = torch.tensor([timestep], device=device, dtype=dtype) - if timestep.ndim == 0: - timestep = timestep[None] - - emb = get_timestep_embedding( - timesteps=timestep, - embedding_dim=self.sinusoidal_dim, - flip_sin_to_cos=False, - downscale_freq_shift=0, - ).to(dtype) - emb = self.proj_in(emb) - emb = self.act(emb) - emb = self.proj_hid(emb) - emb = self.act(emb) - emb = self.proj_out(emb) - return emb - -def flatten( - hid: List[torch.FloatTensor], # List of (*** c) -) -> Tuple[ - torch.FloatTensor, # (L c) - torch.LongTensor, # (b n) -]: - assert len(hid) > 0 - shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) - hid = torch.cat([x.flatten(0, -2) for x in hid]) - return hid, shape - - -def unflatten( - hid: torch.FloatTensor, # (L c) or (L ... c) - hid_shape: torch.LongTensor, # (b n) -) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) - hid_len = hid_shape.prod(-1) - hid = hid.split(hid_len.tolist()) - hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] - return hid - -def repeat( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, torch.LongTensor], # (b) -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - hid = unflatten(hid, hid_shape) - kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] - return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) - -class NaDiT(nn.Module): - - def __init__( - self, - norm_eps, - qk_rope, - num_layers, - mlp_type, - vid_in_channels = 33, - vid_out_channels = 16, - vid_dim = 2560, - txt_in_dim = 5120, - heads = 20, - head_dim = 128, - mm_layers = 10, - expand_ratio = 4, - qk_bias = False, - patch_size = [ 1,2,2 ], - shared_qkv: bool = False, - shared_mlp: bool = False, - window_method: Optional[Tuple[str]] = None, - temporal_window_size: int = None, - temporal_shifted: bool = False, - rope_dim = 128, - rope_type = "mmrope3d", - vid_out_norm: Optional[str] = None, - device = None, - dtype = None, - operations = None, - **kwargs, - ): - self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM - if self._7b_version: - rope_type = "rope3d" - self.dtype = dtype - factory_kwargs = {"device": device, "dtype": dtype} - window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] - txt_dim = vid_dim - emb_dim = vid_dim * 6 - block_type = ["mmdit_sr"] * num_layers - window = num_layers * [(4,3,3)] - ada = AdaSingle - norm = CustomRMSNorm - qk_norm = CustomRMSNorm - if isinstance(block_type, str): - block_type = [block_type] * num_layers - elif len(block_type) != num_layers: - raise ValueError("The ``block_type`` list should equal to ``num_layers``.") - super().__init__() - # ``torch.empty`` returns uninitialized memory, not zeros. The - # SeedVR2Conditioning fail-loud guard at - # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" - # from "buffer was never populated by the file" by checking - # ``positive_conditioning.abs().sum() == 0``. That sentinel is only - # reliable if the post-construction buffer state is deterministically - # zero, so explicitly zero-fill here rather than relying on the - # allocator's zero-on-alloc behavior (allocator-dependent and not - # contractual). When ``load_state_dict`` populates these buffers - # from a properly-baked SeedVR2 .safetensors, the in-place copy - # overwrites the zeros with the universal SeedVR2 conditioning - # tensors (shape (58, 5120) and (64, 5120) bf16). - self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) - self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) - self.vid_in = NaPatchIn( - in_channels=vid_in_channels, - patch_size=patch_size, - dim=vid_dim, - device=device, dtype=dtype, operations=operations - ) - self.txt_in = ( - operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) - if txt_in_dim and txt_in_dim != txt_dim - else nn.Identity() - ) - self.emb_in = TimeEmbedding( - sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM, - hidden_dim=max(vid_dim, txt_dim), - output_dim=emb_dim, - device=device, dtype=dtype, operations=operations - ) - - if window is None or isinstance(window[0], int): - window = [window] * num_layers - if window_method is None or isinstance(window_method, str): - window_method = [window_method] * num_layers - if temporal_window_size is None or isinstance(temporal_window_size, int): - temporal_window_size = [temporal_window_size] * num_layers - if temporal_shifted is None or isinstance(temporal_shifted, bool): - temporal_shifted = [temporal_shifted] * num_layers - - rope_dim = rope_dim if rope_dim is not None else head_dim // 2 - self.blocks = nn.ModuleList( - [ - NaMMSRTransformerBlock( - vid_dim=vid_dim, - txt_dim=txt_dim, - emb_dim=emb_dim, - heads=heads, - head_dim=head_dim, - expand_ratio=expand_ratio, - norm=norm, - norm_eps=norm_eps, - ada=ada, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - shared_qkv=shared_qkv, - shared_mlp=shared_mlp, - mlp_type=mlp_type, - rope_dim = rope_dim, - window=window[i], - window_method=window_method[i], - temporal_window_size=temporal_window_size[i], - temporal_shifted=temporal_shifted[i], - is_last_layer=(i == num_layers - 1) and not self._7b_version, - rope_type = rope_type, - shared_weights=not ( - (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] - ), - version = self._7b_version, - operations = operations, - **kwargs, - **factory_kwargs - ) - for i in range(num_layers) - ] - ) - self.vid_out = NaPatchOut( - out_channels=vid_out_channels, - patch_size=patch_size, - dim=vid_dim, - device=device, dtype=dtype, operations=operations - ) - - self.need_txt_repeat = block_type[0] in [ - "mmdit_stwin", - "mmdit_stwin_spatial", - "mmdit_stwin_3d_spatial", - ] - - self.vid_out_norm = None - if vid_out_norm is not None: - self.vid_out_norm = CustomRMSNorm( - normalized_shape=vid_dim, - eps=norm_eps, - elementwise_affine=True, - device=device, dtype=dtype - ) - self.vid_out_ada = ada( - dim=vid_dim, - emb_dim=emb_dim, - layers=["out"], - modes=["in"], - device=device, dtype=dtype - ) - - def _resolve_text_conditioning(self, context, cond_or_uncond=None): - if context is None or getattr(context, "numel", lambda: None)() == 0: - context = self.positive_conditioning - return flatten([context]) - if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): - if context.shape[0] == 1: - context = context.squeeze(0) - return flatten([context]) - return flatten(context.unbind(0)) - if context.shape[0] % 2 != 0: - raise ValueError(f"SeedVR2 expected an even text-conditioning batch, got shape {tuple(context.shape)}") - neg_cond, pos_cond = context.chunk(2, dim=0) - if pos_cond.shape[0] == 1: - pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) - return flatten([pos_cond, neg_cond]) - return flatten((*pos_cond.unbind(0), *neg_cond.unbind(0))) - - @staticmethod - def _seedvr2_is_single_conditioning_branch(cond_or_uncond): - if cond_or_uncond is None or len(cond_or_uncond) == 0: - return False - first = cond_or_uncond[0] - return all(entry == first for entry in cond_or_uncond) - - def _swap_pos_neg_halves(self, out, cond_or_uncond=None): - if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): - return out - # ``dim=0`` is explicit on both calls. The contract is "split - # the batch axis into two halves and swap them"; making the - # axis load-bearing in source guards against silent drift if a - # future refactor reorders tensor axes. - pos, neg = out.chunk(2, dim=0) - return torch.cat([neg, pos], dim=0) - - def forward( - self, - x, - timestep, - context, # l c - disable_cache: bool = False, # for test # TODO ? // gives an error when set to True - **kwargs - ): - transformer_options = kwargs.get("transformer_options", {}) - patches_replace = transformer_options.get("patches_replace", {}) - blocks_replace = patches_replace.get("dit", {}) - conditions = kwargs.get("condition") - b, tc, h, w = x.shape - x = x.view(b, 16, -1, h, w) - conditions = conditions.view(b, 17, -1, h, w) - x = x.movedim(1, -1) - conditions = conditions.movedim(1, -1) - cache = Cache(disable=disable_cache) - - txt, txt_shape = self._resolve_text_conditioning(context, transformer_options.get("cond_or_uncond")) - - vid, vid_shape = flatten(x) - cond_latent, _ = flatten(conditions) - - vid = torch.cat([vid, cond_latent], dim=-1) - if txt_shape.size(-1) == 1 and self.need_txt_repeat: - txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) - - txt = self.txt_in(txt) - - vid_shape_before_patchify = vid_shape - vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache) - - emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - - for i, block in enumerate(self.blocks): - if ("block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block( - vid=args["vid"], - txt=args["txt"], - vid_shape=args["vid_shape"], - txt_shape=args["txt_shape"], - emb=args["emb"], - cache=args["cache"], - ) - return out - out = blocks_replace[("block", i)]({ - "vid":vid, - "txt":txt, - "vid_shape":vid_shape, - "txt_shape":txt_shape, - "emb":emb, - "cache":cache, - }, {"original_block": block_wrap}) - vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] - else: - vid, txt, vid_shape, txt_shape = block( - vid=vid, - txt=txt, - vid_shape=vid_shape, - txt_shape=txt_shape, - emb=emb, - cache=cache, - ) - - if self.vid_out_norm: - vid = self.vid_out_norm(vid) - vid = self.vid_out_ada( - vid, - emb=emb, - layer="out", - mode="in", - hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), - cache=cache, - branch_tag="vid", - ) - - vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) - vid = unflatten(vid, vid_shape) - out = torch.stack(vid) - out = out.movedim(-1, 1) - out = rearrange(out, "b c t h w -> b (c t) h w") - return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py deleted file mode 100644 index 68b11c0ff..000000000 --- a/comfy/ldm/seedvr/vae.py +++ /dev/null @@ -1,2110 +0,0 @@ -from contextlib import nullcontext -from typing import Literal, Optional, Tuple -import gc -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch import Tensor -from contextlib import contextmanager -from comfy.utils import ProgressBar - -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_BLOCK_OUT_CHANNELS, - BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD, - BYTEDANCE_GN_CHUNKS_FP16, - BYTEDANCE_GN_CHUNKS_FP32, - BYTEDANCE_LOGVAR_CLAMP_MAX, - BYTEDANCE_LOGVAR_CLAMP_MIN, - BYTEDANCE_SLICING_SAMPLE_MIN, - BYTEDANCE_VAE_CONV_MEM_GIB, - BYTEDANCE_VAE_NORM_MEM_GIB, - BYTEDANCE_VAE_SCALING_FACTOR, - BYTEDANCE_VAE_SHIFTING_FACTOR, - BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, - BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, - SEEDVR2_LATENT_CHANNELS, -) -from comfy.ldm.modules.attention import optimized_attention -from comfy.ldm.modules.diffusionmodules.model import vae_attention - -import math -from enum import Enum -from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND - -import logging -import comfy.model_management -import comfy.ops -ops = comfy.ops.disable_weight_init - - -def _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, temporal_scale=1): - if temporal_size is None: - return None - - temporal_size = int(temporal_size) - if temporal_size <= 0: - return 0 - - temporal_overlap = max(0, int(temporal_overlap or 0)) - temporal_overlap = min(temporal_overlap, temporal_size - 1) - temporal_step = temporal_size - temporal_overlap - temporal_scale = max(1, int(temporal_scale)) - return max(1, math.ceil(temporal_step / temporal_scale)) - - -def _seedvr2_clamped_spatial_overlap(overlap, tile_size): - overlap = max(0, int(overlap)) - tile_size = max(1, int(tile_size)) - return min(overlap, tile_size - 1) - - -def _seedvr2_clear_temporal_memory(model): - for module in model.modules(): - if hasattr(module, "memory"): - module.memory = None - - -@torch.inference_mode() -def tiled_vae( - x, - vae_model, - tile_size=(512, 512), - tile_overlap=(64, 64), - temporal_size=16, - temporal_overlap=0, - encode=True, - **kwargs, -): - gc.collect() - comfy.model_management.soft_empty_cache() - - x = x.to(next(vae_model.parameters()).dtype) - if x.ndim != 5: - x = x.unsqueeze(2) - - _, _, d, h, w = x.shape - - sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE) - sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE) - if encode: - slicing_attr = "slicing_sample_min_size" - slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap) - else: - slicing_attr = "slicing_latent_min_size" - slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, sf_t) - if encode: - ti_h, ti_w = tile_size - ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0], ti_h) - ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1], ti_w) - blend_ov_h = max(0, ov_h // sf_s) - blend_ov_w = max(0, ov_w // sf_s) - target_d = (d + sf_t - 1) // sf_t - target_h = (h + sf_s - 1) // sf_s - target_w = (w + sf_s - 1) // sf_s - else: - ti_h = max(1, tile_size[0] // sf_s) - ti_w = max(1, tile_size[1] // sf_s) - ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0] // sf_s, ti_h) - ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1] // sf_s, ti_w) - blend_ov_h = ov_h * sf_s - blend_ov_w = ov_w * sf_s - - target_d = max(1, d * sf_t - (sf_t - 1)) - target_h = h * sf_s - target_w = w * sf_s - - stride_h = max(1, ti_h - ov_h) - stride_w = max(1, ti_w - ov_w) - - storage_device = vae_model.device - result = None - count = None - def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): - device = torch.device(device) - _seedvr2_clear_temporal_memory(model) - t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() - old_device = getattr(model, "device", None) - model.device = device - old_slicing_min_size = getattr(model, slicing_attr, None) - if old_slicing_min_size is not None and slicing_min_size is not None: - if slicing_min_size <= 0: - setattr(model, slicing_attr, t_chunk.shape[2]) - else: - setattr(model, slicing_attr, slicing_min_size) - try: - if encode: - out = model.encode(t_chunk)[0] - else: - out = model.decode_(t_chunk) - finally: - if old_slicing_min_size is not None and slicing_min_size is not None: - setattr(model, slicing_attr, old_slicing_min_size) - if old_device is not None: - model.device = old_device - if isinstance(out, (tuple, list)): - out = out[0] - if out.ndim == 4: - out = out.unsqueeze(2) - return out.to(storage_device) - - ramp_cache = {} - def get_ramp(steps): - if steps not in ramp_cache: - t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32) - ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) - return ramp_cache[steps] - - tile_ranges = [] - for y_idx in range(0, h, stride_h): - y_end = min(y_idx + ti_h, h) - if y_idx > 0 and (y_end - y_idx) <= ov_h: - continue - for x_idx in range(0, w, stride_w): - x_end = min(x_idx + ti_w, w) - if x_idx > 0 and (x_end - x_idx) <= ov_w: - continue - tile_ranges.append((y_idx, y_end, x_idx, x_end)) - - total_tiles = len(tile_ranges) - bar = ProgressBar(total_tiles) - single_spatial_tile = h <= ti_h and w <= ti_w - - _seedvr2_clear_temporal_memory(vae_model) - - def run_tile(tile_index, tile_range): - y_idx, y_end, x_idx, x_end = tile_range - tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] - tile_out = run_temporal_chunks(tile_x) - return tile_index, y_idx, y_end, x_idx, x_end, tile_out - - ordered_tile_outputs = ( - run_tile(tile_index, tile_range) - for tile_index, tile_range in enumerate(tile_ranges) - ) - - for _, y_idx, y_end, x_idx, x_end, tile_out in ordered_tile_outputs: - - if single_spatial_tile: - result = tile_out[:, :, :target_d, :target_h, :target_w] - if result.device != x.device: - result = result.to(x.device).to(x.dtype) - if x.shape[2] == 1 and sf_t == 1: - result = result.squeeze(2) - bar.update(1) - return result - - if result is None: - b_out, c_out = tile_out.shape[0], tile_out.shape[1] - result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) - count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) - - if encode: - ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] - xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] - cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) - else: - ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] - xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] - cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) - - w_h = torch.ones((tile_out.shape[3],), device=storage_device) - w_w = torch.ones((tile_out.shape[4],), device=storage_device) - - if cur_ov_h > 0: - r = get_ramp(cur_ov_h) - if y_idx > 0: - w_h[:cur_ov_h] = r - if y_end < h: - w_h[-cur_ov_h:] = 1.0 - r - - if cur_ov_w > 0: - r = get_ramp(cur_ov_w) - if x_idx > 0: - w_w[:cur_ov_w] = r - if x_end < w: - w_w[-cur_ov_w:] = 1.0 - r - - final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) - - valid_d = min(tile_out.shape[2], result.shape[2]) - tile_out = tile_out[:, :, :valid_d, :, :] - - tile_out.mul_(final_weight) - - result[:, :, :valid_d, ys:ye, xs:xe] += tile_out - count[:, :, :, ys:ye, xs:xe] += final_weight - - del tile_out, final_weight, w_h, w_w - bar.update(1) - - result.div_(count.clamp(min=1e-6)) - _seedvr2_clear_temporal_memory(vae_model) - - if result.device != x.device: - result = result.to(x.device).to(x.dtype) - - if x.shape[2] == 1 and sf_t == 1: - result = result.squeeze(2) - - return result - -_NORM_LIMIT = float("inf") -def get_norm_limit(): - return _NORM_LIMIT - - -def set_norm_limit(value: Optional[float] = None): - global _NORM_LIMIT - if value is None: - value = float("inf") - _NORM_LIMIT = value - -@contextmanager -def ignore_padding(model): - orig_padding = model.padding - model.padding = (0, 0, 0) - try: - yield - finally: - model.padding = orig_padding - -class MemoryState(Enum): - DISABLED = 0 - INITIALIZING = 1 - ACTIVE = 2 - UNSET = 3 - -def get_cache_size(conv_module, input_len, pad_len, dim=0): - dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 - output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 - remain_len = ( - input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) - ) - overlap_len = dilated_kernerl_size - conv_module.stride[dim] - cache_len = overlap_len + remain_len # >= 0 - - assert output_len > 0 - return cache_len - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters: torch.Tensor, deterministic: bool = False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like( - self.mean, device=self.parameters.device, dtype=self.parameters.dtype - ) - - def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: - sample = torch.randn( - self.mean.shape, - generator=generator, - device=self.parameters.device, - dtype=self.parameters.dtype, - ) - x = self.mean + self.std * sample - return x - - def mode(self): - return self.mean - -class SpatialNorm(nn.Module): - def __init__( - self, - f_channels: int, - zq_channels: int, - ): - super().__init__() - self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - f_size = f.shape[-2:] - zq = F.interpolate(zq, size=f_size, mode="nearest") - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f - -# partial implementation of diffusers's Attention for comfyui -class Attention(nn.Module): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - out_dim: int = None, - out_context_dim: int = None, - context_pre_only=None, - pre_only=False, - is_causal: bool = False, - ): - super().__init__() - - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - self.is_causal = is_causal - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if norm_num_groups is not None: - self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) - else: - self.group_norm = None - - if spatial_norm_dim is not None: - self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) - else: - self.spatial_norm = None - - self.norm_q = None - self.norm_k = None - - self.norm_cross = None - self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - self.added_proj_bias = added_proj_bias - if self.added_kv_proj_dim is not None: - self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - if self.context_pre_only is not None: - self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - else: - self.add_q_proj = None - self.add_k_proj = None - self.add_v_proj = None - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - else: - self.to_out = None - - if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) - else: - self.to_add_out = None - - self.norm_added_q = None - self.norm_added_k = None - self.optimized_vae_attention = vae_attention() - - def __call__( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - - residual = hidden_states - if self.spatial_norm is not None: - hidden_states = self.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) - - if self.group_norm is not None: - hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = self.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) - - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - - if self.norm_q is not None: - query = self.norm_q(query) - if self.norm_k is not None: - key = self.norm_k(key) - - if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: - query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) - else: - hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if self.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / self.rescale_output_factor - - return hidden_states - - -def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor): - with torch.no_grad(): - depth = weight_3d.size(2) - weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) - return weight_3d - -def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor): - with torch.no_grad(): - bias_3d.copy_(bias_2d) - return bias_3d - - -def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): - weight_name = prefix + "weight" - bias_name = prefix + "bias" - if weight_name in state_dict: - weight_2d = state_dict[weight_name] - if weight_2d.dim() == 4: - weight_3d = inflate_weight_fn( - weight_2d=weight_2d, - weight_3d=layer.weight, - ) - state_dict[weight_name] = weight_3d - else: - return state_dict - if bias_name in state_dict: - bias_2d = state_dict[bias_name] - if bias_2d.dim() == 1: - bias_3d = inflate_bias_fn( - bias_2d=bias_2d, - bias_3d=layer.bias, - ) - state_dict[bias_name] = bias_3d - return state_dict - -def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: - input_dtype = x.dtype - if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): - if x.ndim == 4: - x = rearrange(x, "b c h w -> b h w c") - x = norm_layer(x) - x = rearrange(x, "b h w c -> b c h w") - return x.to(input_dtype) - if x.ndim == 5: - x = rearrange(x, "b c t h w -> b t h w c") - x = norm_layer(x) - x = rearrange(x, "b t h w c -> b c t h w") - return x.to(input_dtype) - if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): - if x.ndim <= 4: - return norm_layer(x).to(input_dtype) - if x.ndim == 5: - t = x.size(2) - x = rearrange(x, "b c t h w -> (b t) c h w") - memory_occupy = x.numel() * x.element_size() / 1024**3 - if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): - num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) - assert norm_layer.num_groups % num_chunks == 0 - num_groups_per_chunk = norm_layer.num_groups // num_chunks - - x = list(x.chunk(num_chunks, dim=1)) - weights = norm_layer.weight.chunk(num_chunks, dim=0) - biases = norm_layer.bias.chunk(num_chunks, dim=0) - for i, (w, b) in enumerate(zip(weights, biases)): - x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) - x[i] = x[i].to(input_dtype) - x = torch.cat(x, dim=1) - else: - x = norm_layer(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - return x.to(input_dtype) - raise NotImplementedError - -def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): - problematic_modes = ['bilinear', 'bicubic', 'trilinear'] - - if mode in problematic_modes: - try: - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - except RuntimeError as e: - if ("not implemented for 'Half'" in str(e) or - "compute_indices_weights" in str(e)): - original_dtype = x.dtype - return F.interpolate( - x.float(), - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ).to(original_dtype) - else: - raise e - else: - # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - -_receptive_field_t = Literal["half", "full"] - -def extend_head(tensor, times: int = 2, memory = None): - if memory is not None: - return torch.cat((memory.to(tensor), tensor), dim=2) - assert times >= 0, "Invalid input for function 'extend_head'!" - if times == 0: - return tensor - else: - tile_repeat = [1] * tensor.ndim - tile_repeat[2] = times - return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) - -def cache_send_recv(tensor, cache_size, times, memory=None): - recv_buffer = None - - if memory is not None: - recv_buffer = memory.to(tensor[0]) - elif times > 0: - tile_repeat = [1] * tensor[0].ndim - tile_repeat[2] = times - recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) - - return recv_buffer - -class InflatedCausalConv3d(ops.Conv3d): - def __init__( - self, - *args, - inflation_mode, - memory_device = "same", - **kwargs, - ): - self.inflation_mode = inflation_mode - self.memory = None - super().__init__(*args, **kwargs) - self.temporal_padding = self.padding[0] - self.memory_device = memory_device - self.padding = (0, *self.padding[1:]) - self.memory_limit = float("inf") - self.logged_once = False - - def set_memory_limit(self, value: float): - self.memory_limit = value - - def set_memory_device(self, memory_device): - self.memory_device = memory_device - - def _conv_forward(self, input, weight, bias, *args, **kwargs): - if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and - weight.dtype in (torch.float16, torch.bfloat16) and - hasattr(torch.backends.cudnn, 'is_available') and - torch.backends.cudnn.is_available() and - getattr(torch.backends.cudnn, 'enabled', True)): - try: - out = torch.cudnn_convolution( - input, weight, self.padding, self.stride, self.dilation, self.groups, - benchmark=False, deterministic=False, allow_tf32=True - ) - if bias is not None: - out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) - return out - except RuntimeError: - pass - except NotImplementedError: - pass - try: - return super()._conv_forward(input, weight, bias, *args, **kwargs) - except NotImplementedError: - # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend - if not self.logged_once: - logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory") - self.logged_once = True - return F.conv3d(input, weight, bias, *args, **kwargs) - - def memory_limit_conv( - self, - x, - *, - split_dim=3, - padding=(0, 0, 0, 0, 0, 0), - prev_cache=None, - ): - # Compatible with no limit. - if math.isinf(self.memory_limit): - if prev_cache is not None: - x = torch.cat([prev_cache, x], dim=split_dim - 1) - return super().forward(x) - - # Compute tensor shape after concat & padding. - shape = torch.tensor(x.size()) - if prev_cache is not None: - shape[split_dim - 1] += prev_cache.size(split_dim - 1) - shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) - memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB - if memory_occupy < self.memory_limit or split_dim == x.ndim: - x_concat = x - if prev_cache is not None: - x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) - - def pad_and_forward(): - padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) - if not padded.is_contiguous(): - padded = padded.contiguous() - with ignore_padding(self): - return torch.nn.Conv3d.forward(self, padded) - - return pad_and_forward() - - num_splits = math.ceil(memory_occupy / self.memory_limit) - size_per_split = x.size(split_dim) // num_splits - split_sizes = [size_per_split] * (num_splits - 1) - split_sizes += [x.size(split_dim) - sum(split_sizes)] - - x = list(x.split(split_sizes, dim=split_dim)) - if prev_cache is not None: - prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) - cache = None - for idx in range(len(x)): - if prev_cache is not None: - x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) - - lpad_dim = (x[idx].ndim - split_dim - 1) * 2 - rpad_dim = lpad_dim + 1 - padding = list(padding) - padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 - padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 - pad_len = padding[lpad_dim] + padding[rpad_dim] - padding = tuple(padding) - - next_cache = None - cache_len = cache.size(split_dim) if cache is not None else 0 - next_catch_size = get_cache_size( - conv_module=self, - input_len=x[idx].size(split_dim) + cache_len, - pad_len=pad_len, - dim=split_dim - 2, - ) - if next_catch_size != 0: - assert next_catch_size <= x[idx].size(split_dim) - next_cache = ( - x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) - ) - - x[idx] = self.memory_limit_conv( - x[idx], - split_dim=split_dim + 1, - padding=padding, - prev_cache=cache - ) - - cache = next_cache - - output = torch.cat(x, dim=split_dim) - return output - - def forward( - self, - input, - memory_state: MemoryState = MemoryState.UNSET - ) -> Tensor: - assert memory_state != MemoryState.UNSET - if memory_state != MemoryState.ACTIVE: - self.memory = None - if ( - math.isinf(self.memory_limit) - and torch.is_tensor(input) - ): - return self.basic_forward(input, memory_state) - return self.slicing_forward(input, memory_state) - - def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): - mem_size = self.stride[0] - self.kernel_size[0] - if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): - input = extend_head(input, memory=self.memory, times=-1) - else: - input = extend_head(input, times=self.temporal_padding * 2) - memory = ( - input[:, :, mem_size:].detach() - if (mem_size != 0 and memory_state != MemoryState.DISABLED) - else None - ) - if ( - memory_state != MemoryState.DISABLED - and not self.training - and (self.memory_device is not None) - ): - self.memory = memory - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - return super().forward(input) - - def slicing_forward( - self, - input, - memory_state: MemoryState = MemoryState.UNSET, - ) -> Tensor: - squeeze_out = False - if torch.is_tensor(input): - input = [input] - squeeze_out = True - - cache_size = self.kernel_size[0] - self.stride[0] - cache = cache_send_recv( - input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 - ) - - # Single GPU inference - simplified memory management - if ( - memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing - and not self.training - and (self.memory_device is not None) - and cache_size != 0 - ): - if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: - input[0] = torch.cat([cache, input[0]], dim=2) - cache = None - if cache_size <= input[-1].size(2): - self.memory = input[-1][:, :, -cache_size:].detach().contiguous() - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - - padding = tuple(x for x in reversed(self.padding) for _ in range(2)) - for i in range(len(input)): - # Prepare cache for next input slice. - next_cache = None - cache_size = 0 - if i < len(input) - 1: - cache_len = cache.size(2) if cache is not None else 0 - cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) - if cache_size != 0: - if cache_size > input[i].size(2) and cache is not None: - input[i] = torch.cat([cache, input[i]], dim=2) - cache = None - assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" - next_cache = input[i][:, :, -cache_size:] - - # Conv forward for this input slice. - input[i] = self.memory_limit_conv( - input[i], - padding=padding, - prev_cache=cache - ) - - # Update cache. - cache = next_cache - - return input[0] if squeeze_out else input - -def remove_head(tensor: Tensor, times: int = 1) -> Tensor: - if times == 0: - return tensor - return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) - -class Upsample3D(nn.Module): - - def __init__( - self, - channels, - out_channels = None, - inflation_mode = "tail", - temporal_up: bool = False, - spatial_up: bool = True, - slicing: bool = False, - interpolate = True, - name: str = "conv", - use_conv_transpose = False, - use_conv: bool = False, - padding = 1, - bias = True, - kernel_size = None, - **kwargs, - ): - super().__init__() - self.interpolate = interpolate - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv_transpose = use_conv_transpose - self.use_conv = use_conv - self.name = name - - self.conv = None - if use_conv_transpose: - if kernel_size is None: - kernel_size = 4 - self.conv = ops.ConvTranspose2d( - channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias - ) - elif use_conv: - if kernel_size is None: - kernel_size = 3 - self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) - - conv = self.conv if self.name == "conv" else self.Conv2d_0 - - # Note: lora_layer is not passed into constructor in the original implementation. - # So we make a simplification. - conv = InflatedCausalConv3d( - self.channels, - self.out_channels, - 3, - padding=1, - inflation_mode=inflation_mode, - ) - - self.temporal_up = temporal_up - self.spatial_up = spatial_up - self.temporal_ratio = 2 if temporal_up else 1 - self.spatial_ratio = 2 if spatial_up else 1 - self.slicing = slicing - - assert not self.interpolate - # [Override] MAGViT v2 implementation - if not self.interpolate: - upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio - self.upscale_conv = ops.Conv3d( - self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 - ) - identity = ( - torch.eye(self.channels) - .repeat(upscale_ratio, 1) - .reshape_as(self.upscale_conv.weight) - ) - self.upscale_conv.weight.data.copy_(identity) - - if self.name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - self.norm = None - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state=None, - **kwargs, - ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv_transpose: - return self.conv(hidden_states) - - if self.slicing: - split_size = hidden_states.size(2) // 2 - hidden_states = list( - hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) - ) - else: - hidden_states = [hidden_states] - - for i in range(len(hidden_states)): - hidden_states[i] = self.upscale_conv(hidden_states[i]) - hidden_states[i] = rearrange( - hidden_states[i], - "b (x y z c) f h w -> b c (f z) (h x) (w y)", - x=self.spatial_ratio, - y=self.spatial_ratio, - z=self.temporal_ratio, - ) - - if self.temporal_up and memory_state != MemoryState.ACTIVE: - hidden_states[0] = remove_head(hidden_states[0]) - - if not self.slicing: - hidden_states = hidden_states[0] - - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states, memory_state=memory_state) - else: - hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) - - if not self.slicing: - return hidden_states - else: - return torch.cat(hidden_states, dim=2) - - -class Downsample3D(nn.Module): - """A 3D downsampling layer with an optional convolution.""" - - def __init__( - self, - channels, - out_channels = None, - inflation_mode = "tail", - spatial_down: bool = False, - temporal_down: bool = False, - name: str = "conv", - kernel_size=3, - use_conv: bool = False, - padding = 1, - bias=True, - **kwargs, - ): - super().__init__() - self.padding = padding - self.name = name - self.channels = channels - self.out_channels = out_channels or channels - self.temporal_down = temporal_down - self.spatial_down = spatial_down - self.use_conv = use_conv - self.padding = padding - - self.temporal_ratio = 2 if temporal_down else 1 - self.spatial_ratio = 2 if spatial_down else 1 - - self.temporal_kernel = 3 if temporal_down else 1 - self.spatial_kernel = 3 if spatial_down else 1 - - if use_conv: - conv = InflatedCausalConv3d( - self.channels, - self.out_channels, - kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - padding=( - 1 if self.temporal_down else 0, - self.padding if self.spatial_down else 0, - self.padding if self.spatial_down else 0, - ), - inflation_mode=inflation_mode, - ) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool3d( - kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - ) - - self.conv = conv - - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state = None, - **kwargs, - ) -> torch.FloatTensor: - - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv and self.padding == 0 and self.spatial_down: - pad = (0, 1, 0, 1) - hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - - hidden_states = self.conv(hidden_states, memory_state=memory_state) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - eps: float = 1e-6, - non_linearity: str = "swish", - time_embedding_norm: str = "default", - output_scale_factor: float = 1.0, - skip_time_act: bool = False, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - slicing: bool = False, - **kwargs, - ): - super().__init__() - self.up = up - self.down = down - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - conv_2d_out_channels = conv_2d_out_channels or out_channels - self.use_in_shortcut = use_in_shortcut - self.output_scale_factor = output_scale_factor - self.skip_time_act = skip_time_act - self.nonlinearity = nn.SiLU() - if temb_channels is not None: - self.time_emb_proj = ops.Linear(temb_channels, out_channels) - else: - self.time_emb_proj = None - self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - if groups_out is None: - groups_out = groups - self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - self.use_in_shortcut = self.in_channels != out_channels - self.dropout = torch.nn.Dropout(dropout) - self.conv1 = InflatedCausalConv3d( - self.in_channels, - self.out_channels, - kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), - stride=1, - padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), - inflation_mode=inflation_mode, - ) - - self.conv2 = InflatedCausalConv3d( - self.out_channels, - conv_2d_out_channels, - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = Upsample3D( - self.in_channels, - use_conv=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - elif self.down: - self.downsample = Downsample3D( - self.in_channels, - use_conv=False, - padding=1, - name="op", - inflation_mode=inflation_mode, - ) - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = InflatedCausalConv3d( - self.in_channels, - conv_2d_out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=True, - inflation_mode=inflation_mode, - ) - - def forward( - self, input_tensor, temb, memory_state = None, **kwargs - ): - hidden_states = input_tensor - - hidden_states = causal_norm_wrapper(self.norm1, hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, memory_state=memory_state) - hidden_states = self.upsample(hidden_states, memory_state=memory_state) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, memory_state=memory_state) - hidden_states = self.downsample(hidden_states, memory_state=memory_state) - - hidden_states = self.conv1(hidden_states, memory_state=memory_state) - - if self.time_emb_proj is not None: - if not self.skip_time_act: - temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb)[:, :, None, None] - - if temb is not None: - hidden_states = hidden_states + temb - - hidden_states = causal_norm_wrapper(self.norm2, hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, memory_state=memory_state) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_down: bool = True, - spatial_down: bool = True, - ): - super().__init__() - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - temporal_down=temporal_down, - spatial_down=spatial_down, - inflation_mode=inflation_mode, - ) - ] - ) - else: - self.downsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state = None, - **kwargs, - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temb_channels: Optional[int] = None, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up: bool = True, - spatial_up: bool = True, - slicing: bool = False, - ): - super().__init__() - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - # [Override] Replace module. - ResnetBlock3D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - slicing=slicing, - ) - ) - - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_upsample: - # [Override] Replace module & use learnable upsample - self.upsamplers = nn.ModuleList( - [ - Upsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - temporal_up=temporal_up, - spatial_up=spatial_up, - interpolate=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - ] - ) - else: - self.upsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - memory_state=None - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UNetMidBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ] - attentions = [] - - if attention_head_dim is None: - attention_head_dim = in_channels - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - Attention( - in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=( - resnet_groups if resnet_time_scale_shift == "default" else None - ), - spatial_norm_dim=( - temb_channels if resnet_time_scale_shift == "spatial" else None - ), - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - else: - attentions.append(None) - - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None, memory_state=None): - video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") - hidden_states = attn(hidden_states, temb=temb) - hidden_states = rearrange( - hidden_states, "(b f) c h w -> b c f h w", f=video_length - ) - hidden_states = resnet(hidden_states, temb, memory_state=memory_state) - - return hidden_states - - -class Encoder3D(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - mid_block_add_attention=True, - # [Override] add extra_cond_dim, temporal down num - temporal_down_num: int = 2, - extra_cond_dim: int = None, - gradient_checkpoint: bool = False, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_down_num = temporal_down_num - - self.conv_in = InflatedCausalConv3d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - self.extra_cond_dim = extra_cond_dim - - self.conv_extra_cond = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - # [Override] to support temporal down block design - is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 - # Note: take the last ones - - assert down_block_type == "DownEncoderBlock3D" - - down_block = DownEncoderBlock3D( - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - # Note: Don't know why set it as 0 - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - temporal_down=is_temporal_down_block, - spatial_down=True, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.down_blocks.append(down_block) - - def zero_module(module): - # Zero out the parameters of a module and return it. - for p in module.parameters(): - p.detach().zero_() - return module - - self.conv_extra_cond.append( - zero_module( - ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) - ) - if self.extra_cond_dim is not None and self.extra_cond_dim > 0 - else None - ) - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # out - self.conv_norm_out = ops.GroupNorm( - num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = InflatedCausalConv3d( - block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - def forward( - self, - sample: torch.FloatTensor, - extra_cond=None, - memory_state = None - ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state = memory_state) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, use_reentrant=False - ) - if extra_block is not None: - sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample) - - else: - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = down_block(sample, memory_state=memory_state) - if extra_block is not None: - sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state = memory_state) - - return sample - - -class Decoder3D(nn.Module): - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - mid_block_add_attention=True, - # [Override] add temporal up block - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_up_num = temporal_up_num - - self.conv_in = InflatedCausalConv3d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - is_temporal_up_block = i < self.temporal_up_num - is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num - # Note: Keep symmetric - - assert up_block_type == "UpDecoderBlock3D" - up_block = UpDecoderBlock3D( - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=norm_type, - temb_channels=temb_channels, - temporal_up=is_temporal_up_block, - slicing=is_slicing_up_block, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = ops.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - self.conv_out = InflatedCausalConv3d( - block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - # Note: Just copy from Decoder. - def forward( - self, - sample: torch.FloatTensor, - latent_embeds: Optional[torch.FloatTensor] = None, - memory_state = None, - ) -> torch.FloatTensor: - - sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state=memory_state) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # middle - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) - - return sample - -class VideoAutoencoderKL(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - layers_per_block: int = 2, - act_fn: str = "silu", - latent_channels: int = SEEDVR2_LATENT_CHANNELS, - norm_num_groups: int = 32, - attention: bool = True, - temporal_scale_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - inflation_mode = "pad", - time_receptive_field: _receptive_field_t = "full", - use_quant_conv: bool = False, - use_post_quant_conv: bool = False, - slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, - *args, - **kwargs, - ): - self.slicing_sample_min_size = slicing_sample_min_size - self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) - extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None - block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS - down_block_types = ("DownEncoderBlock3D",) * 4 - up_block_types = ("UpDecoderBlock3D",) * 4 - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder3D( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - extra_cond_dim=extra_cond_dim, - # [Override] add temporal_down_num parameter - temporal_down_num=temporal_scale_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # pass init params to Decoder - self.decoder = Decoder3D( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - # [Override] add temporal_up_num parameter - temporal_up_num=temporal_scale_num, - slicing_up_num=slicing_up_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - self.quant_conv = ( - InflatedCausalConv3d( - in_channels=2 * latent_channels, - out_channels=2 * latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_quant_conv - else None - ) - self.post_quant_conv = ( - InflatedCausalConv3d( - in_channels=latent_channels, - out_channels=latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_post_quant_conv - else None - ) - - # A hacky way to remove attention. - if not attention: - self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) - self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) - - self.use_slicing = True - - def encode(self, x: torch.FloatTensor, return_dict: bool = True): - h = self.slicing_encode(x) - posterior = DiagonalGaussianDistribution(h).mode() - - if not return_dict: - return (posterior,) - - return posterior - - def decode_( - self, z: torch.Tensor, return_dict: bool = True - ): - decoded = self.slicing_decode(z) - - if not return_dict: - return (decoded,) - - return decoded - - def _encode( - self, x, memory_state = MemoryState.DISABLED - ) -> torch.Tensor: - _x = x.to(self.device) - h = self.encoder(_x, memory_state=memory_state) - if self.quant_conv is not None: - output = self.quant_conv(h, memory_state=memory_state) - else: - output = h - return output.to(x.device) - - def _decode( - self, z, memory_state = MemoryState.DISABLED - ) -> torch.Tensor: - _z = z.to(self.device) - - if self.post_quant_conv is not None: - _z = self.post_quant_conv(_z, memory_state=memory_state) - - output = self.decoder(_z, memory_state=memory_state) - return output.to(z.device) - - def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size =1 - if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: - split_size = max( - self.slicing_sample_min_size * sp_size, - getattr(self, "temporal_downsample_factor", 1), - ) - x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) - min_active_len = getattr(self, "temporal_downsample_factor", 1) - if len(x_slices) > 1 and x_slices[-1].shape[2] < min_active_len: - x_slices[-2] = torch.cat((x_slices[-2], x_slices[-1]), dim=2) - x_slices.pop() - encoded_slices = [ - self._encode( - torch.cat((x[:, :, :1], x_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING, - ) - ] - for x_idx in range(1, len(x_slices)): - encoded_slices.append( - self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) - ) - out = torch.cat(encoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None - return out - else: - return self._encode(x) - - def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = 1 - if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: - z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) - decoded_slices = [ - self._decode( - torch.cat((z[:, :, :1], z_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING - ) - ] - for z_idx in range(1, len(z_slices)): - decoded_slices.append( - self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) - ) - out = torch.cat(decoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None - return out - else: - return self._decode(z) - - def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs - ): - # x: [b c t h w] - def _unwrap(value): - return value[0] if isinstance(value, tuple) else value - - if mode == "encode": - return _unwrap(self.encode(x)) - elif mode == "decode": - return _unwrap(self.decode_(x)) - else: - latent = _unwrap(self.encode(x)) - return _unwrap(self.decode_(latent)) - -class VideoAutoencoderKLWrapper(VideoAutoencoderKL): - def __init__( - self, - *args, - spatial_downsample_factor = 8, - temporal_downsample_factor = 4, - freeze_encoder = True, - **kwargs, - ): - self.spatial_downsample_factor = spatial_downsample_factor - self.temporal_downsample_factor = temporal_downsample_factor - self.freeze_encoder = freeze_encoder - self.enable_tiling = False - super().__init__(*args, **kwargs) - self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) - - def forward(self, x: torch.FloatTensor): - with torch.no_grad() if self.freeze_encoder else nullcontext(): - z, p = self.encode(x) - x = self.decode(z) - return x, z, p - - def encode(self, x, orig_dims=None): - if x.ndim == 4: - x = x.unsqueeze(2) - x = x.to(dtype=next(self.parameters()).dtype) - self.device = x.device - p = super().encode(x) - z = p.squeeze(2) - return z, p - - def decode(self, z, seedvr2_tiling=None): - seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling - if not isinstance(seedvr2_tiling, dict): - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: `seedvr2_tiling` must be a dict; " - f"got {type(seedvr2_tiling).__name__} with value {seedvr2_tiling!r}." - ) - - if z.ndim == 5: - b, c, t_latent, h, w = z.shape - if c != 16: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " - f"have 16 channels; got shape {tuple(z.shape)}." - ) - latent = z - elif z.ndim == 4: - b, tc, h, w = z.shape - if tc % 16 != 0: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " - "use collapsed channel layout (B, 16*T, H, W); " - f"got shape {tuple(z.shape)}." - ) - latent = z.reshape(b, 16, -1, h, w) - else: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " - "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " - f"got shape {tuple(z.shape)}." - ) - scale = BYTEDANCE_VAE_SCALING_FACTOR - shift = BYTEDANCE_VAE_SHIFTING_FACTOR - latent = latent / scale + shift - - self.device = latent.device - self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) - - if self.enable_tiling: - decode_seedvr2_args = dict(seedvr2_tiling) - tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) - ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) - decode_seedvr2_args["tile_overlap"] = ( - min(ov_h, max(0, tile_h - 8)), - min(ov_w, max(0, tile_w - 8)), - ) - x = tiled_vae(latent, self, **decode_seedvr2_args, encode=False) - if x.ndim == 4: - # tiled_vae squeezes the temporal axis when - # temporal_downsample_factor == 1 AND latent T == 1 - # (see tiled_vae line 179-180); re-add it so the post-decode - # pipeline can keep batch and time distinct on the tiled path. - x = x.unsqueeze(2) - else: - x = super().decode_(latent) - - # ensure even dims for save video - h, w = x.shape[-2:] - w2 = w - (w % 2) - h2 = h - (h % 2) - x = x[..., :h2, :w2] - - return x - - def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"): - set_norm_limit(norm_max_mem) - for m in self.modules(): - if isinstance(m, InflatedCausalConv3d): - m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) - - for module in self.modules(): - if isinstance(module, InflatedCausalConv3d): - module.set_memory_device(memory_device) diff --git a/comfy/model_base.py b/comfy/model_base.py index c084e23bb..042804771 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,8 +54,6 @@ import comfy.ldm.pixeldit.model import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 -import comfy.ldm.seedvr.model - import comfy.ldm.qwen_image.model import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model @@ -930,16 +928,6 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out -class SeedVR2(BaseModel): - def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) - def extra_conds(self, **kwargs): - out = super().extra_conds(**kwargs) - condition = kwargs.get("condition", None) - if condition is not None: - out["condition"] = comfy.conds.CONDRegular(condition) - return out - class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 955581006..74c838d13 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -598,56 +598,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config - if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 3072 - dit_config["heads"] = 24 - dit_config["num_layers"] = 36 - # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` - # submodules) at EVERY block — verified by inspecting the 7B - # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means - # ``MMModule.shared_weights=False``). Native NaDiT computes - # per-block ``shared_weights = not (i < mm_layers)``, so to keep - # every block non-shared we set ``mm_layers = num_layers``. - # Without this, blocks at index >= mm_layers (default 10) try to - # load ``blocks.N.*.all.*`` keys that don't exist in the file, - # silently miss-load → all-black output. - dit_config["mm_layers"] = 36 - dit_config["norm_eps"] = 1e-5 - dit_config["qk_rope"] = True - dit_config["rope_type"] = "rope3d" - dit_config["rope_dim"] = 64 - dit_config["mlp_type"] = "normal" - return dit_config - elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 3072 - dit_config["heads"] = 24 - dit_config["num_layers"] = 36 - # This checkpoint layout carries shared ``all.`` MMModule keys. - # Preserve the historical split: the initial blocks use separate - # vid/txt modules, later blocks use shared modules. - dit_config["mm_layers"] = 10 - dit_config["norm_eps"] = 1e-5 - dit_config["qk_rope"] = True - dit_config["rope_type"] = "rope3d" - dit_config["rope_dim"] = 64 - dit_config["mlp_type"] = "swiglu" - return dit_config - elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 2560 - dit_config["heads"] = 20 - dit_config["num_layers"] = 32 - dit_config["norm_eps"] = 1.0e-05 - dit_config["qk_rope"] = None - dit_config["mlp_type"] = "swiglu" - dit_config["vid_out_norm"] = True - return dit_config - if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/sample.py b/comfy/sample.py index de71596b3..2be0cae5f 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -44,13 +44,7 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None, is_empty = torch.count_nonzero(latent_image) == 0 if is_empty: if latent_format.latent_channels != latent_image.shape[1]: - preserves_collapsed_channels = ( - getattr(latent_format, "preserve_empty_channel_multiples", False) - and latent_image.ndim == 4 - and latent_image.shape[1] % latent_format.latent_channels == 0 - ) - if not preserves_collapsed_channels: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) if downscale_ratio_spacial is not None: if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio diff --git a/comfy/sd.py b/comfy/sd.py index 8ac08ac42..a66ba1bfb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,3 @@ -import inspect import json import torch from enum import Enum @@ -17,7 +16,6 @@ import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae -import comfy.ldm.seedvr.vae import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae @@ -86,36 +84,6 @@ import comfy.latent_formats import comfy.ldm.flux.redux -SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160 - - -def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w): - output_t = max(1, (latent_t - 1) * 4 + 1) - return output_t * latent_h * 8 * latent_w * 8 - - -def _seedvr2_vae_decode_memory_used(shape): - if len(shape) == 5: - candidates = [] - if shape[1] == 16: - candidates.append((shape[2], shape[3], shape[4])) - if shape[-1] == 16: - candidates.append((shape[1], shape[2], shape[3])) - if len(candidates) == 0: - candidates.append((shape[2], shape[3], shape[4])) - output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates) - elif len(shape) == 4: - latent_t = max(1, (shape[1] + 15) // 16) - latent_h, latent_w = shape[2], shape[3] - output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) - else: - latent_t, latent_h, latent_w = 1, shape[-2], shape[-1] - output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) - # SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels - # plus int64 sort indices dominate peak memory, not the VAE weight dtype. - return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL - - def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): key_map = {} if model is not None: @@ -499,10 +467,8 @@ class CLIP: class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): - is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd - if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - if metadata is None or metadata.get("keep_diffusers_format") != "true": - sd = diffusers_convert.convert_vae_state_dict(sd) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -574,20 +540,6 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 - self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.latent_channels = 16 - self.latent_dim = 3 - self.disable_offload = True - self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape) - self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) - self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) - self.downscale_index_formula = (4, 8, 8) - self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - self.upscale_index_formula = (4, 8, 8) - self.process_input = lambda image: image * 2.0 - 1.0 - self.crop_input = False elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} @@ -715,7 +667,6 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] - elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] @@ -1055,40 +1006,6 @@ class VAE: decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) - def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4): - sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8) - sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4) - if tile_t is None: - tile_t = 16 - if overlap_t is None: - overlap_t = 4 - if tile_t > 0: - temporal_size = tile_t * sf_t - temporal_overlap = max(0, overlap_t) * sf_t - else: - temporal_size = 0 - temporal_overlap = 0 - args = { - "enable_tiling": True, - "tile_size": (tile_y * sf_s, tile_x * sf_s), - "tile_overlap": (overlap * sf_s, overlap * sf_s), - "temporal_size": temporal_size, - "temporal_overlap": temporal_overlap, - } - output = self.first_stage_model.decode( - samples.to(self.vae_dtype).to(self.device), - seedvr2_tiling=args, - ) - return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) - - def _format_seedvr2_encoded_samples(self, samples): - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - if samples.ndim == 4: - samples = samples.unsqueeze(2) - samples = samples.contiguous() - samples = samples * 0.9152 - return samples - def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -1125,36 +1042,6 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): - if tile_y is None: - tile_y = 512 - if tile_x is None: - tile_x = 512 - if overlap is None: - overlap_y = 64 - overlap_x = 64 - else: - overlap_y = overlap - overlap_x = overlap - if tile_t is None: - tile_t = 9999 - if overlap_t is None: - overlap_t = 0 - overlap_y = min(overlap_y, max(0, tile_y - 8)) - overlap_x = min(overlap_x, max(0, tile_x - 8)) - self.first_stage_model.device = self.device - x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device) - output = comfy.ldm.seedvr.vae.tiled_vae( - x, - self.first_stage_model, - tile_size=(tile_y, tile_x), - tile_overlap=(overlap_y, overlap_x), - temporal_size=tile_t, - temporal_overlap=overlap_t, - encode=True, - ) - return output.to(device=self.output_device, dtype=self.vae_output_dtype()) - def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1202,40 +1089,16 @@ class VAE: if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - # SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)`` - # downstream of ``SeedVR2Conditioning`` (which performs the - # ``rearrange(b c t h w -> b (c t) h w)`` collapse). The - # generic ``decode_tiled_`` would treat the channel dim as - # spatial-only and crash on the collapsed (16, T) layout - # under ``tiled_scale``'s mask broadcast; route SeedVR2 4D - # latents to ``decode_tiled_seedvr2`` instead, whose wrapper - # dispatch handles both 4D and 5D inputs. - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - tile = 256 // self.spacial_compression_decode() - overlap = tile // 4 - pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) - else: - pixel_samples = self.decode_tiled_(samples_in) + pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) - else: - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples - def decode_tiled( - self, - samples, - tile_x=None, - tile_y=None, - overlap=None, - tile_t=None, - overlap_t=None, - ): + def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -1249,20 +1112,7 @@ class VAE: args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3): - seedvr2_args = {} - if tile_x is not None: - seedvr2_args["tile_x"] = tile_x - if tile_y is not None: - seedvr2_args["tile_y"] = tile_y - if overlap is not None: - seedvr2_args["overlap"] = overlap - if tile_t is not None: - seedvr2_args["tile_t"] = tile_t - if overlap_t is not None: - seedvr2_args["overlap_t"] = overlap_t - output = self.decode_tiled_seedvr2(samples, **seedvr2_args) - elif dims == 1 or self.extra_1d_channel is not None: + if dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: @@ -1304,8 +1154,6 @@ class VAE: else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) - if isinstance(out, tuple): - out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1325,23 +1173,20 @@ class VAE: if self.latent_dim == 3: tile = 256 overlap = tile // 4 - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) - else: - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) elif self.latent_dim == 1 or self.extra_1d_channel is not None: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) - return self._format_seedvr2_encoded_samples(samples) + return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) - if dims == 3 and pixel_samples.ndim < 5: + if dims == 3: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: @@ -1365,47 +1210,22 @@ class VAE: elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - seedvr2_args = {} - if tile_x is not None: - seedvr2_args["tile_x"] = tile_x - else: - seedvr2_args["tile_x"] = 512 - if tile_y is not None: - seedvr2_args["tile_y"] = tile_y - else: - seedvr2_args["tile_y"] = 512 - if overlap is not None: - seedvr2_args["overlap"] = overlap - else: - seedvr2_args["overlap"] = 64 - if tile_t is not None: - seedvr2_args["tile_t"] = tile_t - else: - seedvr2_args["tile_t"] = 9999 - if overlap_t is not None: - seedvr2_args["overlap_t"] = overlap_t - else: - seedvr2_args["overlap_t"] = 0 - samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args) + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) else: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) - else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - spatial_overlap = overlap if overlap is not None else 64 - if overlap_t is None: - args["overlap"] = (1, spatial_overlap, spatial_overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) - return self._format_seedvr2_encoded_samples(samples) + return samples def get_sd(self): return self.first_stage_model.state_dict() @@ -1932,17 +1752,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) - -def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): - set_dtype = model_config.set_inference_dtype - parameters = inspect.signature(set_dtype).parameters - supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) - if supports_device: - set_dtype(dtype, manual_cast_dtype, device=device) - else: - set_dtype(dtype, manual_cast_dtype) - - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -2050,7 +1859,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -2191,7 +2000,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if custom_operations is not None: model_config.custom_operations = custom_operations diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fa95003cc..7cf9c133b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1672,35 +1672,6 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) -class SeedVR2(supported_models_base.BASE): - unet_config = { - "image_model": "seedvr2" - } - latent_format = comfy.latent_formats.SeedVR2 - - vae_key_prefix = ["vae."] - text_encoder_key_prefix = ["text_encoders."] - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] - sampling_settings = { - "shift": 1.0, - } - - def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): - if ( - dtype == torch.float16 - and manual_cast_dtype is None - and comfy.model_management.should_use_bf16(device) - ): - manual_cast_dtype = torch.bfloat16 - super().set_inference_dtype(dtype, manual_cast_dtype, device=device) - - def get_model(self, state_dict, prefix="", device=None): - out = model_base.SeedVR2(self, device=device) - return out - - def clip_target(self, state_dict={}): - return None - class ChromaRadiance(Chroma): unet_config = { "image_model": "chroma_radiance", @@ -2058,6 +2029,7 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + class RT_DETR_v4(supported_models_base.BASE): unet_config = { "image_model": "RT_DETR_v4", @@ -2295,7 +2267,6 @@ models = [ HiDream, HiDreamO1, Chroma, - SeedVR2, ChromaRadiance, ACEStep, ACEStep15, diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 572f9984e..0e7a829ba 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -115,7 +115,7 @@ class BASE: replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): + def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py deleted file mode 100644 index d5cd029ba..000000000 --- a/comfy_extras/nodes_seedvr.py +++ /dev/null @@ -1,1015 +0,0 @@ -from typing_extensions import override -from comfy_api.latest import ComfyExtension, io -import torch -import math -import logging -from einops import rearrange - -import gc -import comfy.model_management -import comfy.sample -import comfy.samplers -from comfy.ldm.seedvr.color_fix import ( - adain_color_transfer, - lab_color_transfer, - wavelet_color_transfer, -) -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_IMG_SHIFT_FIT, - BYTEDANCE_SCHEDULE_T, - BYTEDANCE_VID_SHIFT_FIT, - SEEDVR2_ADAIN_SCALE_MULTIPLIER, - SEEDVR2_COLOR_MEM_HEADROOM, - SEEDVR2_COND_CHANNELS, - SEEDVR2_DTYPE_BYTES_FLOOR, - SEEDVR2_LAB_SCALE_MULTIPLIER, - SEEDVR2_LATENT_CHANNELS, - SEEDVR2_OOM_BACKOFF_DIVISOR, - SEEDVR2_WAVELET_SCALE_MULTIPLIER, -) - -from torchvision.transforms import functional as TVF -from torchvision.transforms import Lambda -from torchvision.transforms.functional import InterpolationMode - - -_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( - "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" -) - -# Private sentinel for getattr default: distinguishes "attribute missing" -# from "attribute present but None" so the failure message is accurate. -_ATTR_MISSING = object() - - -def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): - """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" - attempts = [frames_per_chunk] - current_chunk_latent = ( - t_latent if t_pixel <= frames_per_chunk - else (frames_per_chunk - 1) // 4 + 1 - ) - current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) - seen = {frames_per_chunk} - - for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): - chunk_latent = max(1, math.ceil(t_latent / target_chunks)) - candidate = 4 * (chunk_latent - 1) + 1 - if candidate in seen: - continue - if candidate >= attempts[-1]: - continue - attempts.append(candidate) - seen.add(candidate) - - return attempts - - -def _resolve_seedvr2_diffusion_model(model): - """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" - inner = getattr(model, "model", _ATTR_MISSING) - if inner is _ATTR_MISSING: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute " - f"(got type {type(model).__name__})." - ) - if inner is None: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None " - f"(input type {type(model).__name__})." - ) - diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING) - if diffusion_model is _ATTR_MISSING: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no " - f"'diffusion_model' attribute (got type {type(inner).__name__})." - ) - if diffusion_model is None: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' " - f"is None (model.model type {type(inner).__name__})." - ) - return diffusion_model - - -def _apply_rope_freqs_float32_cast(diffusion_model): - """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" - for module in diffusion_model.modules(): - if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): - if module.rope.freqs.data.dtype != torch.float32: - module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) - - -def clear_vae_memory(vae_model): - for module in vae_model.modules(): - if hasattr(module, "memory"): - module.memory = None - gc.collect() - comfy.model_management.soft_empty_cache() - -def expand_dims(tensor, ndim): - shape = tensor.shape + (1,) * (ndim - tensor.ndim) - return tensor.reshape(shape) - -def get_conditions(latent, latent_blur): - t, h, w, c = latent.shape - cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 - return cond - -def timestep_transform(timesteps, latents_shapes): - vt = 4 - vs = 8 - frames = (latents_shapes[:, 0] - 1) * vt + 1 - heights = latents_shapes[:, 1] * vs - widths = latents_shapes[:, 2] * vs - - # Compute shift factor. - def get_lin_function(x1, y1, x2, y2): - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT) - vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT) - shift = torch.where( - frames > 1, - vid_shift_fn(heights * widths * frames), - img_shift_fn(heights * widths), - ).to(timesteps.device) - - # Shift timesteps. - T = BYTEDANCE_SCHEDULE_T - timesteps = timesteps / T - timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) - timesteps = timesteps * T - return timesteps - -def inter(x_0, x_T, t): - t = expand_dims(t, x_0.ndim) - T = BYTEDANCE_SCHEDULE_T - B = lambda t: t / T - A = lambda t: 1 - (t / T) - return A(t) * x_0 + B(t) * x_T - -def div_pad(image, factor): - - height_factor, width_factor = factor - height, width = image.shape[-2:] - - pad_height = (height_factor - (height % height_factor)) % height_factor - pad_width = (width_factor - (width % width_factor)) % width_factor - - if pad_height == 0 and pad_width == 0: - return image - - if isinstance(image, torch.Tensor): - padding = (0, pad_width, 0, pad_height) - image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) - - return image - -def cut_videos(videos): - t = videos.size(1) - if t == 1: - return videos - if t <= 4 : - padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 - ((t - 1) % (4)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4) == 0 - return videos - -def _seedvr2_input_shorter_edge(images, node_name): - if images.dim() == 4: - return min(images.shape[1], images.shape[2]) - if images.dim() == 5: - return min(images.shape[2], images.shape[3]) - raise ValueError( - f"{node_name}: expected 4-D or 5-D IMAGE tensor, " - f"got shape {tuple(images.shape)}" - ) - - -def _seedvr2_pad(images, upscaled_shorter_edge, node_name): - if upscaled_shorter_edge < 2: - raise ValueError( - f"{node_name}: input shorter edge must be at least 2 pixels; " - f"got {upscaled_shorter_edge}." - ) - if images.shape[-1] > 3: - images = images[..., :3] - if images.dim() == 4: - # Comfy video components arrive as a 4-D IMAGE frame sequence: - # (frames, H, W, C). SeedVR2 consumes that as one video. - images = images.unsqueeze(0) - elif images.dim() != 5: - raise ValueError( - f"{node_name}: expected 4-D or 5-D IMAGE tensor, " - f"got shape {tuple(images.shape)}" - ) - images = images.permute(0, 1, 4, 2, 3) - - b, t, c, h, w = images.shape - images = images.reshape(b * t, c, h, w) - - clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - images = clip(images) - images = div_pad(images, (16, 16)) - _, _, new_h, new_w = images.shape - - images = images.reshape(b, t, c, new_h, new_w) - images = cut_videos(images) - images_bthwc = rearrange(images, "b t c h w -> b t h w c") - - return io.NodeOutput(images_bthwc) - - -class SeedVR2Preprocess(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2Preprocess", - display_name="Pre-Process SeedVR2 Input", - category="image/upscaling", - description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.", - inputs=[ - io.Image.Input("resized_images", tooltip="The resized image to process."), - ], - outputs=[ - io.Image.Output("images"), - ] - ) - - @classmethod - def execute(cls, resized_images): - upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess") - return _seedvr2_pad( - resized_images, upscaled_shorter_edge, "SeedVR2Preprocess", - ) - - -class SeedVR2PostProcessing(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2PostProcessing", - display_name="Post-Process SeedVR2 Output", - category="image/upscaling", - description="Align the generated image with the original resized image and apply color correction.", - inputs=[ - io.Image.Input("images", tooltip="The generated image to process."), - io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."), - io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."), - ], - outputs=[io.Image.Output(display_name="images")], - ) - - @classmethod - def execute(cls, images, original_resized_images, color_correction_method): - alpha_input = None - if original_resized_images.shape[-1] == 4: - alpha_input = original_resized_images[..., 3:4] - original_resized_images = original_resized_images[..., :3] - decoded_5d, decoded_was_4d = cls._as_bthwc(images) - reference_full, _ = cls._as_bthwc(original_resized_images) - decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full) - - b = min(decoded_5d.shape[0], reference_full.shape[0]) - t = min(decoded_5d.shape[1], reference_full.shape[1]) - reference_h = reference_full.shape[2] - reference_w = reference_full.shape[3] - - decoded_5d = decoded_5d[:b, :t, :, :, :] - target_h = min(decoded_5d.shape[2], reference_h) - target_w = min(decoded_5d.shape[3], reference_w) - decoded_5d = decoded_5d[:, :, :target_h, :target_w, :] - if color_correction_method in ("lab", "wavelet", "adain"): - reference_5d = reference_full[:b, :t, :, :, :] - reference_5d = cls._resize_reference(reference_5d, target_h, target_w) - output_device = decoded_5d.device - decoded_raw = cls._to_seedvr2_raw(decoded_5d) - reference_raw = cls._to_seedvr2_raw(reference_5d) - decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") - reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") - output = cls._color_transfer_chunked( - decoded_flat, reference_flat, output_device, color_correction_method, - ) - output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) - output = output.add(1.0).div(2.0).clamp(0.0, 1.0) - elif color_correction_method == "none": - output = decoded_5d - else: - raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") - - if alpha_input is not None: - alpha_5d, _ = cls._as_bthwc(alpha_input) - alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :] - output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1) - h2 = output.shape[-3] - (output.shape[-3] % 2) - w2 = output.shape[-2] - (output.shape[-2] % 2) - output = output[:, :, :h2, :w2, :] - if decoded_was_4d: - output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1]) - return io.NodeOutput(output) - - @staticmethod - def _as_bthwc(images): - if images.ndim == 4: - return images.unsqueeze(0), True - if images.ndim == 5: - return images, False - raise ValueError( - f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}" - ) - - @staticmethod - def _restore_reference_batch_time(decoded, reference): - if decoded.shape[0] != 1: - return decoded - ref_b, ref_t = reference.shape[:2] - if ref_b < 1 or decoded.shape[1] % ref_b != 0: - return decoded - decoded_t = decoded.shape[1] // ref_b - if decoded_t < ref_t: - return decoded - return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4]) - - @staticmethod - def _to_seedvr2_raw(images): - return images.mul(2.0).sub(1.0) - - @staticmethod - def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn): - color_device = comfy.model_management.vae_device() - decoded_flat = decoded_flat.to(device=color_device) - reference_flat = reference_flat.to(device=color_device) - output = transfer_fn(decoded_flat, reference_flat) - return output.to(device=output_device) - - @staticmethod - def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device): - color_device = comfy.model_management.vae_device() - result = None - for start in range(decoded_flat.shape[0]): - decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone() - reference_frame = reference_flat[start:start + 1].to(device=color_device).clone() - output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device) - if result is None: - result = torch.empty( - (decoded_flat.shape[0],) + tuple(output.shape[1:]), - device=output_device, - dtype=output.dtype, - ) - result[start:start + 1].copy_(output) - if result is None: - raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.") - return result - - @classmethod - def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): - chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) - while True: - next_chunk_size = None - try: - return cls._run_color_transfer_chunks( - decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if chunk_size <= 1: - raise RuntimeError( - "SeedVR2PostProcessing: color correction OOM at one frame; " - f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." - ) from e - next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) - - comfy.model_management.soft_empty_cache() - chunk_size = next_chunk_size - - @classmethod - def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): - result = None - for start in range(0, decoded_flat.shape[0], chunk_size): - end = min(start + chunk_size, decoded_flat.shape[0]) - decoded_chunk = decoded_flat[start:end] - reference_chunk = reference_flat[start:end] - if color_correction_method == "lab": - output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device) - elif color_correction_method == "wavelet": - output = cls._color_transfer_on_vae_device( - decoded_chunk, reference_chunk, output_device, wavelet_color_transfer, - ) - else: - output = cls._color_transfer_on_vae_device( - decoded_chunk, reference_chunk, output_device, adain_color_transfer, - ) - if result is None: - result = torch.empty( - (decoded_flat.shape[0],) + tuple(output.shape[1:]), - device=output_device, - dtype=output.dtype, - ) - result[start:end].copy_(output) - if result is None: - raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.") - return result - - @classmethod - def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method): - multiplier = cls._color_correction_memory_multiplier(color_correction_method) - frames = decoded_flat.shape[0] - _, channels, height, width = decoded_flat.shape - dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR) - bytes_per_frame = height * width * channels * dtype_bytes * multiplier - if bytes_per_frame <= 0: - return frames - color_device = comfy.model_management.vae_device() - free_memory = comfy.model_management.get_free_memory(color_device) - chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame) - return max(1, min(frames, chunk_size)) - - @staticmethod - def _color_correction_memory_multiplier(color_correction_method): - if color_correction_method == "lab": - return SEEDVR2_LAB_SCALE_MULTIPLIER - if color_correction_method == "wavelet": - return SEEDVR2_WAVELET_SCALE_MULTIPLIER - if color_correction_method == "adain": - return SEEDVR2_ADAIN_SCALE_MULTIPLIER - raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") - - @staticmethod - def _resize_reference(reference, height, width): - if reference.shape[2] == height and reference.shape[3] == width: - return reference - b, t = reference.shape[:2] - reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") - resized = TVF.resize( - reference_flat, - size=(height, width), - interpolation=InterpolationMode.BICUBIC, - antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), - ) - return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) - - -class SeedVR2Conditioning(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2Conditioning", - display_name="Apply SeedVR2 Conditioning", - category="conditioning", - description="Build SeedVR2 positive/negative conditioning from a VAE latent.", - inputs=[ - io.Model.Input("model", tooltip="The SeedVR2 model."), - io.Latent.Input("vae_conditioning", display_name="latent"), - ], - outputs=[ - io.Model.Output(display_name = "model"), - io.Conditioning.Output(display_name = "positive"), - io.Conditioning.Output(display_name = "negative"), - io.Latent.Output(display_name = "latent"), - ], - ) - - @classmethod - def execute(cls, model, vae_conditioning) -> io.NodeOutput: - - vae_conditioning = vae_conditioning["samples"] - if vae_conditioning.ndim != 5: - raise ValueError( - "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " - f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." - ) - if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: - raise ValueError( - "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " - f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " - f"got channel-last shape {tuple(vae_conditioning.shape)}." - ) - vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() - model_patcher = model - model = _resolve_seedvr2_diffusion_model(model_patcher) - pos_cond = model.positive_conditioning - neg_cond = model.negative_conditioning - - # Fail-loud guard against silently-wrong output when a - # DiT-only ``.safetensors`` (no ``positive_conditioning`` / - # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. - # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see - # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` - # leaves them at zero when the keys are absent. Detect that state - # here rather than at ``BaseModel.extra_conds`` (per sampling step, - # wasteful) or at the resolver helper (mixes structural shape with - # semantic content). Both buffers must be checked together — partial - # bake regressions could populate one but not the other. - if ( - pos_cond.float().abs().sum().item() == 0 - and neg_cond.float().abs().sum().item() == 0 - ): - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " - f"and negative_conditioning buffers are zero-valued — model " - f"file appears to be a DiT-only export missing " - f"the SeedVR2 conditioning tensors. " - f"Re-bake the file with ``positive_conditioning`` (58, 5120) " - f"and ``negative_conditioning`` (64, 5120) keys at top level, " - f"or load via CheckpointLoaderSimple from a bundled " - f"checkpoint." - ) - - _apply_rope_freqs_float32_cast(model) - - condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) - condition = condition.movedim(-1, 1) - latent = vae_conditioning.movedim(-1, 1) - - latent = rearrange(latent, "b c t h w -> b (c t) h w") - condition = rearrange(condition, "b c t h w -> b (c t) h w") - - negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] - positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] - - return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) - -def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, - t_end: int, channels: int) -> torch.Tensor: - """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" - B, CT, H, W = tensor_4d.shape - if CT % channels != 0: - raise ValueError( - f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " - f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." - ) - T = CT // channels - if not (0 <= t_start < t_end <= T): - raise ValueError( - f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " - f"range for T={T}." - ) - new_T = t_end - t_start - sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() - return sliced.reshape(B, channels * new_T, H, W) - - -def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): - """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" - new_list = [] - for entry in cond_list: - text_cond, options = entry[0], entry[1] - if "condition" not in options: - new_list.append(entry) - continue - new_options = options.copy() - new_options["condition"] = _slice_collapsed_4d_along_t( - new_options["condition"], t_start, t_end, - SEEDVR2_COND_CHANNELS, - ) - new_list.append([text_cond, new_options]) - return new_list - - -def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, - samples_4d: torch.Tensor, - t_start: int, - t_end: int): - """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" - if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: - return _slice_collapsed_4d_along_t( - noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, - ) - return noise_mask - - -def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: - """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" - if len(chunks_4d) == 0: - raise ValueError("_concat_chunks_along_t: empty chunk list.") - fives = [] - for ch in chunks_4d: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " - f"channel dim {CT} not divisible by channels={channels}." - ) - T = CT // channels - fives.append(ch.reshape(B, channels, T, H, W)) - cat = torch.cat(fives, dim=2).contiguous() - B, C, T_total, H, W = cat.shape - return cat.reshape(B, C * T_total, H, W) - - -def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: - """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): - Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` - (dead-band would collapse a tiny transition). Window shape matched to the reference - overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. - """ - if overlap < 1: - raise ValueError( - f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." - ) - if overlap >= 3: - t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) - blend_start = 1.0 / 3.0 - blend_end = 2.0 / 3.0 - u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) - return 0.5 + 0.5 * torch.cos(torch.pi * u) - return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) - - -def _blend_overlap_region(prev_tail_5d: torch.Tensor, - cur_head_5d: torch.Tensor) -> torch.Tensor: - """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" - if prev_tail_5d.shape != cur_head_5d.shape: - raise ValueError( - f"_blend_overlap_region: shape mismatch " - f"prev {tuple(prev_tail_5d.shape)} vs " - f"cur {tuple(cur_head_5d.shape)}." - ) - overlap = int(prev_tail_5d.shape[2]) - w_prev_1d = _hann_blend_weights_1d( - overlap, prev_tail_5d.device, prev_tail_5d.dtype, - ) - # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. - w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) - w_cur = 1.0 - w_prev - return prev_tail_5d * w_prev + cur_head_5d * w_cur - - -def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, - overlap_latent: int) -> torch.Tensor: - """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" - if len(chunk_specs) == 0: - raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") - if overlap_latent < 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: overlap_latent must be " - f">= 0; got {overlap_latent}." - ) - - # Validate channel divisibility once and capture per-chunk T. - chunk_5d = [] - for t_start, t_end, ch in chunk_specs: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk shape " - f"{tuple(ch.shape)} channel dim {CT} not divisible " - f"by channels={channels}." - ) - T = CT // channels - if t_end - t_start != T: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " - f"declared range [{t_start}:{t_end}]." - ) - chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) - - if overlap_latent == 0: - # Fast path: pure concat in the caller-provided chunk order. - return _concat_chunks_along_t( - [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) - for _, _, c in chunk_5d], - channels, - ) - - T_total = max(t_end for _, t_end, _ in chunk_5d) - first_5d = chunk_5d[0][2] - B = first_5d.shape[0] - H = first_5d.shape[3] - W = first_5d.shape[4] - result = torch.empty( - (B, channels, T_total, H, W), - device=first_5d.device, dtype=first_5d.dtype, - ) - filled_until = 0 - for i, (cs, ce, ct_5d) in enumerate(chunk_5d): - chunk_T = int(ct_5d.shape[2]) - if i == 0: - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - continue - # Overlap region width is bounded by both the previous fill - # frontier and the current chunk's actual length (for runt - # final chunks shorter than the configured overlap). - overlap_len = min(filled_until - cs, chunk_T) - if overlap_len > 0: - prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() - cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() - blended = _blend_overlap_region(prev_tail, cur_head) - result[:, :, cs:cs + overlap_len, :, :] = blended - tail_start = cs + overlap_len - tail_end = ce - if tail_end > tail_start: - result[:, :, tail_start:tail_end, :, :] = ( - ct_5d[:, :, overlap_len:, :, :] - ) - else: - # Disjoint chunks (overlap_latent set but this pair did not - # actually overlap, e.g. step_latent equal to chunk_latent - # in a degenerate config). Treat as concat. - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - - return result.contiguous().reshape(B, channels * T_total, H, W) - - -def _run_standard_sample(model, seed: int, steps: int, cfg: float, - sampler_name: str, scheduler: str, - positive, negative, latent: dict, - denoise: float) -> dict: - """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" - samples_in = latent["samples"] - samples_in = comfy.sample.fix_empty_latent_channels( - model, samples_in, latent.get("downscale_ratio_spacial", None), - ) - batch_inds = latent.get("batch_index", None) - noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) - noise_mask = latent.get("noise_mask", None) - samples = comfy.sample.sample( - model, noise, steps, cfg, sampler_name, scheduler, - positive, negative, samples_in, - denoise=denoise, noise_mask=noise_mask, seed=seed, - ) - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = samples - return out - - -class SeedVR2ProgressiveSampler(io.ComfyNode): - """Sequential temporal chunking sampler for SeedVR2 native. - - Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that - OOM on long sequences. The latent enters the sampler in SeedVR2's - collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` - at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that - tensor along the temporal axis, runs the configured inner sampler - sequentially per chunk against the standard ``comfy.sample.sample`` - entry point, and concatenates per-chunk outputs back into a single - ``(B, 16*T_total, H, W)`` latent. - - ``frames_per_chunk`` is expressed in pixel-frame units to match the - SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the - VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` - maps to ``(F - 1) // 4 + 1`` latent-frame chunks. - - Determinism contract: a single noise tensor is generated once from - the user seed and sliced per chunk (rather than re-seeding each - chunk), so a workflow that fits in a single chunk produces output - identical to a workflow that fits in N chunks at the same seed, - modulo the inherent T-axis chunk-boundary independence of the model. - """ - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2ProgressiveSampler", - display_name="Sample SeedVR2 (Progressive)", - category="sampling", - description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", - inputs=[ - io.Model.Input("model", tooltip="The model used for denoising the input latent."), - io.Int.Input("seed", default=0, min=0, - max=0xffffffffffffffff, - control_after_generate=True, - tooltip="The random seed used for creating the noise."), - io.Int.Input("steps", default=20, min=1, max=10000, - tooltip="The number of steps used in the denoising process."), - io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, - step=0.1, round=0.01, - tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), - io.Combo.Input("sampler_name", - options=comfy.samplers.SAMPLER_NAMES, - tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), - io.Combo.Input("scheduler", - options=comfy.samplers.SCHEDULER_NAMES, - tooltip="The scheduler controls how noise is gradually removed to form the image."), - io.Conditioning.Input("positive", - tooltip="The conditioning describing the attributes you want to include in the image."), - io.Conditioning.Input("negative", - tooltip="The conditioning describing the attributes you want to exclude from the image."), - io.Latent.Input("latent", - tooltip="The latent image to denoise."), - io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, - step=0.01, - tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), - io.Int.Input("frames_per_chunk", default=21, min=1, - max=16384, step=4, - tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), - io.Int.Input("temporal_overlap", default=0, min=0, - max=16384, - tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), - io.Combo.Input("chunking_mode", - options=["manual", "auto"], - default="manual", - tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), - ], - outputs=[io.Latent.Output(display_name="latent")], - ) - - @classmethod - def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - frames_per_chunk, temporal_overlap, - chunking_mode="manual") -> io.NodeOutput: - # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline - # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), - # imposed at ``cut_videos`` upstream and propagated through the VAE's - # temporal_downsample_factor=4. Reject violations explicitly before - # any model invocation; a silent rounding would mis-align chunk - # boundaries with the 4n+1 lattice. - if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " - f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " - f"got {frames_per_chunk}." - ) - - samples_4d = latent["samples"] - samples_4d = comfy.sample.fix_empty_latent_channels( - model, samples_4d, - latent.get("downscale_ratio_spacial", None), - ) - if samples_4d.ndim != 4: - raise ValueError( - f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " - f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." - ) - B, CT, H, W = samples_4d.shape - if CT % SEEDVR2_LATENT_CHANNELS != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " - f"not divisible by SeedVR2 latent channels " - f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " - f"SeedVR2-shaped." - ) - T_latent = CT // SEEDVR2_LATENT_CHANNELS - T_pixel = 4 * (T_latent - 1) + 1 - - if chunking_mode not in ("manual", "auto"): - raise ValueError( - f"SeedVR2ProgressiveSampler: chunking_mode must be " - f"'manual' or 'auto'; got {chunking_mode!r}." - ) - - if chunking_mode == "auto": - attempts = _seedvr2_auto_chunk_attempts( - T_latent, T_pixel, frames_per_chunk, - ) - for i, attempt_frames_per_chunk in enumerate(attempts): - retry = False - try: - return cls.execute( - model=model, seed=seed, steps=steps, cfg=cfg, - sampler_name=sampler_name, scheduler=scheduler, - positive=positive, negative=negative, - latent=latent, denoise=denoise, - frames_per_chunk=attempt_frames_per_chunk, - temporal_overlap=temporal_overlap, - chunking_mode="manual", - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if i == len(attempts) - 1: - raise RuntimeError( - "SeedVR2ProgressiveSampler: exhausted auto " - "chunking attempts after OOM. Tried " - f"frames_per_chunk values {attempts}." - ) from e - retry = True - - if retry: - logging.warning( - "SeedVR2ProgressiveSampler auto chunking OOM at " - "frames_per_chunk=%s; retrying with " - "frames_per_chunk=%s.", - attempt_frames_per_chunk, attempts[i + 1], - ) - comfy.model_management.soft_empty_cache() - - # Short-circuit: total fits in one chunk -> standard path with no - # chunking overhead. Output of this branch is byte-identical to the - # built-in KSampler given the same (model, seed, steps, cfg, - # sampler_name, scheduler, positive, negative, latent, - # denoise) tuple. - if T_pixel <= frames_per_chunk: - return io.NodeOutput(_run_standard_sample( - model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - )) - - # Map pixel chunk -> latent chunk. Each chunk's latent length is - # at most ``chunk_latent``; the final chunk may be a runt that - # is automatically 4n+1-aligned in the pixel domain by the - # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer - # T_latent corresponds to a valid 4n+1 pixel count). - chunk_latent = (frames_per_chunk - 1) // 4 + 1 - - # ``temporal_overlap`` is exposed in latent-frame units, but users - # do not know the derived latent chunk length. Treat oversized - # values as "maximum valid overlap" while preserving a strictly - # positive chunk-loop stride. - if temporal_overlap < 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " - f"got {temporal_overlap}." - ) - temporal_overlap = min(temporal_overlap, chunk_latent - 1) - step_latent = chunk_latent - temporal_overlap - - # Generate full noise once from the user seed, then slice along T - # per chunk. Using one global noise tensor (rather than re-seeding - # per chunk) preserves seed-determinism across chunk-count - # variations: the same (seed, total T_latent) always produces the - # same noise samples regardless of how the work is partitioned. - batch_inds = latent.get("batch_index", None) - noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) - - noise_mask = latent.get("noise_mask", None) - - # Build the flat list of chunk ranges first so the chunking - # geometry is fully known before any sample call. - chunk_ranges = [] - for chunk_start in range(0, T_latent, step_latent): - chunk_end = min(chunk_start + chunk_latent, T_latent) - if chunk_start >= chunk_end: - # The final iteration of a stride that lands exactly on - # T_latent produces a zero-length chunk; skip it. - break - chunk_ranges.append((chunk_start, chunk_end)) - if chunk_end >= T_latent: - break - - def _sample_one_chunk(chunk_start, chunk_end): - samples_chunk = _slice_collapsed_4d_along_t( - samples_4d, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - noise_chunk = _slice_collapsed_4d_along_t( - noise_full, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - positive_chunk = _slice_seedvr2_cond_along_t( - positive, chunk_start, chunk_end, - ) - negative_chunk = _slice_seedvr2_cond_along_t( - negative, chunk_start, chunk_end, - ) - - # Per-chunk noise_mask handling: standard masks are passed - # through for KSampler expansion; pre-expanded collapsed - # masks are sliced. - chunk_noise_mask = None - if noise_mask is not None: - chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( - noise_mask, samples_4d, chunk_start, chunk_end, - ) - - return comfy.sample.sample( - model, noise_chunk, steps, cfg, sampler_name, scheduler, - positive_chunk, negative_chunk, samples_chunk, - denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, - ) - - chunk_specs = [] - for chunk_start, chunk_end in chunk_ranges: - chunk_samples = _sample_one_chunk(chunk_start, chunk_end) - chunk_specs.append((chunk_start, chunk_end, chunk_samples)) - - final = _concat_chunks_with_overlap_blend( - chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, - ) - - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = final - return io.NodeOutput(out) - - -class SeedVRExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[io.ComfyNode]]: - return [ - SeedVR2Conditioning, - SeedVR2Preprocess, - SeedVR2PostProcessing, - SeedVR2ProgressiveSampler, - ] - -async def comfy_entrypoint() -> SeedVRExtension: - return SeedVRExtension() diff --git a/nodes.py b/nodes.py index d9ac53ede..2f5a478b5 100644 --- a/nodes.py +++ b/nodes.py @@ -47,18 +47,14 @@ import node_helpers if args.enable_manager: import comfyui_manager - def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() - def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) - MAX_RESOLUTION=16384 - class CLIPTextEncode(ComfyNodeABC): @classmethod def INPUT_TYPES(s) -> InputTypeDict: @@ -327,8 +323,8 @@ class VAEDecodeTiled: return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" @@ -338,32 +334,18 @@ class VAEDecodeTiled: def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: overlap = tile_size // 4 + if temporal_size < temporal_overlap * 2: + temporal_overlap = temporal_overlap // 2 temporal_compression = vae.temporal_compression_decode() if temporal_compression is not None: - if temporal_size <= 0: - temporal_size = 0 - temporal_overlap = 0 - else: - requested_temporal_overlap = temporal_overlap - if temporal_size < temporal_overlap * 2: - temporal_overlap = temporal_overlap // 2 - temporal_size = max(2, temporal_size // temporal_compression) - temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression) - if requested_temporal_overlap > 0: - temporal_overlap = max(1, temporal_overlap) + temporal_size = max(2, temporal_size // temporal_compression) + temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression)) else: temporal_size = None temporal_overlap = None compression = vae.spacial_compression_decode() - images = vae.decode_tiled( - samples["samples"], - tile_x=tile_size // compression, - tile_y=tile_size // compression, - overlap=overlap // compression, - tile_t=temporal_size, - overlap_t=temporal_overlap, - ) + images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -380,7 +362,7 @@ class VAEEncode: def encode(self, vae, pixels): t = vae.encode(pixels) - return ({"samples": t}, ) + return ({"samples":t}, ) class VAEEncodeTiled: @classmethod @@ -388,8 +370,8 @@ class VAEEncodeTiled: return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" @@ -397,9 +379,6 @@ class VAEEncodeTiled: CATEGORY = "experimental" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): - if temporal_size <= 0: - temporal_size = 0 - temporal_overlap = 0 t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) @@ -2439,7 +2418,6 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", - "nodes_seedvr.py", "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py deleted file mode 100644 index 2a6e3d430..000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Consolidated SeedVR2 conditioning and refactor regression tests. - -Merges the prior test_seedvr2_refactor_nodes.py and -test_seedvr_conditioning_hardening.py modules. Refactor tests use the -top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests -use _import_nodes_seedvr_isolated() for sys.modules isolation when -mocking comfy.model_management. -""" - -import importlib -import sys -from unittest.mock import MagicMock - -import pytest -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - - -_SENTINEL = object() -_TARGETS = ( - ("comfy.model_management", "comfy"), - ("comfy_extras.nodes_seedvr", "comfy_extras"), -) - - -def _import_nodes_seedvr_isolated(): - """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.""" - priors = [] - for mod_name, parent_name in _TARGETS: - prior_mod = sys.modules.get(mod_name, _SENTINEL) - parent = sys.modules.get(parent_name) - attr = mod_name.split(".")[-1] - prior_attr = ( - getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL - ) - priors.append((mod_name, parent_name, attr, prior_mod, prior_attr)) - - mock_mm = MagicMock() - for fn in ( - "xformers_enabled", "xformers_enabled_vae", - "pytorch_attention_enabled", "pytorch_attention_enabled_vae", - "sage_attention_enabled", "flash_attention_enabled", - "is_intel_xpu", - ): - getattr(mock_mm, fn).return_value = False - tv = torch.version.__version__.split(".") - mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1])) - mock_mm.WINDOWS = False - sys.modules["comfy.model_management"] = mock_mm - if sys.modules.get("comfy") is None: - import comfy as _comfy_pkg # noqa: F401 - comfy_pkg = sys.modules.get("comfy") - if comfy_pkg is not None: - setattr(comfy_pkg, "model_management", mock_mm) - nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or ( - importlib.import_module("comfy_extras.nodes_seedvr") - ) - - def _restore(): - for mod_name, parent_name, attr, prior_mod, prior_attr in priors: - if prior_mod is _SENTINEL: - sys.modules.pop(mod_name, None) - else: - sys.modules[mod_name] = prior_mod - parent = sys.modules.get(parent_name) - if parent is None: - continue - if prior_attr is _SENTINEL: - if hasattr(parent, attr): - delattr(parent, attr) - else: - setattr(parent, attr, prior_attr) - - return nodes_seedvr, _restore - - -class _Rope(nn.Module): - """Minimal RoPE stub exposing a `freqs` parameter.""" - def __init__(self): - super().__init__() - self.freqs = nn.Parameter(torch.zeros(4)) - - -class _Block(nn.Module): - """Minimal transformer block stub holding a `_Rope`.""" - def __init__(self): - super().__init__() - self.rope = _Rope() - - -class _DiffusionModel(nn.Module): - """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" - def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): - super().__init__() - self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) - pos = torch.zeros if zero_conditioning else torch.ones - self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) - self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) - - -class _ModelInner: - """Inner model wrapper exposing `.diffusion_model`.""" - def __init__(self, diffusion_model): - self.diffusion_model = diffusion_model - - -class _ModelPatcher: - """ModelPatcher stub exposing `.model._ModelInner`.""" - def __init__(self, diffusion_model): - self.model = _ModelInner(diffusion_model) - - -def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - schema = nodes_seedvr.SeedVR2Conditioning.define_schema() - assert [input_item.id for input_item in schema.inputs] == [ - "model", - "vae_conditioning", - ] - assert schema.inputs[1].display_name == "latent" - assert [output.display_name for output in schema.outputs] == [ - "model", - "positive", - "negative", - "latent", - ] - finally: - restore() - - -def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel() - patcher = _ModelPatcher(diffusion_model) - samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) - vae_conditioning = {"samples": samples} - - _, first_positive, first_negative, first_latent = ( - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, - vae_conditioning, - ) - ) - _, second_positive, second_negative, second_latent = ( - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, - vae_conditioning, - ) - ) - - expected_latent = samples.reshape(1, 6, 2, 2) - channel_last = samples.movedim(1, -1).contiguous() - expected_condition = torch.cat( - [ - channel_last, - torch.ones((*channel_last.shape[:-1], 1)), - ], - dim=-1, - ).movedim(-1, 1).reshape(1, 9, 2, 2) - - assert torch.equal(first_latent["samples"], expected_latent) - assert torch.equal(second_latent["samples"], expected_latent) - assert torch.equal( - first_positive[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - second_positive[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - first_negative[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - second_negative[0][1]["condition"], - expected_condition, - ) - finally: - restore() - - -def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel(zero_conditioning=True) - patcher = _ModelPatcher(diffusion_model) - vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} - - with pytest.raises(RuntimeError) as excinfo: - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, vae_conditioning, - ) - - message = str(excinfo.value) - assert message.startswith( - nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX - ), ( - "Fail-loud message must use the standard " - "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " - f"can match it. Got: {message!r}" - ) - assert "positive_conditioning" in message - assert "negative_conditioning" in message - finally: - restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py deleted file mode 100644 index f7d9a4f65..000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py +++ /dev/null @@ -1,55 +0,0 @@ -import importlib -import inspect -import sys -from unittest.mock import MagicMock, patch - -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - - -def test_seedvr_node_signature_matches_schema(): - mock_mm = MagicMock() - mock_mm.xformers_enabled.return_value = False - mock_mm.xformers_enabled_vae.return_value = False - mock_mm.sage_attention_enabled.return_value = False - mock_mm.flash_attention_enabled.return_value = False - - sentinel = object() - prior_cpu = cli_args.cpu - cli_args.cpu = True - prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) - comfy_pkg = sys.modules.get("comfy") - prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel - - with patch.dict(sys.modules, {"comfy.model_management": mock_mm}): - if comfy_pkg is not None: - setattr(comfy_pkg, "model_management", mock_mm) - sys.modules.pop("comfy_extras.nodes_seedvr", None) - try: - nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") - for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): - schema_ids = [i.id for i in node_cls.define_schema().inputs] - exec_params = [ - p for p in inspect.signature(node_cls.execute).parameters.keys() - if p != "cls" - ] - assert schema_ids == exec_params, ( - f"{node_cls.__name__} schema/execute drift: " - f"schema_ids={schema_ids}, exec_params={exec_params}" - ) - finally: - cli_args.cpu = prior_cpu - if prior_module is sentinel: - sys.modules.pop("comfy_extras.nodes_seedvr", None) - else: - sys.modules["comfy_extras.nodes_seedvr"] = prior_module - if comfy_pkg is not None: - if prior_mm_attr is sentinel: - if hasattr(comfy_pkg, "model_management"): - delattr(comfy_pkg, "model_management") - else: - setattr(comfy_pkg, "model_management", prior_mm_attr) diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py deleted file mode 100644 index a27a8f8df..000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py +++ /dev/null @@ -1,57 +0,0 @@ -from unittest.mock import patch - -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -from comfy_extras import nodes_seedvr # noqa: E402 - - -def _schema_ids(items): - return [item.id for item in items] - - -def test_seedvr2_post_processing_schema(): - schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() - - assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"] - assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"] - assert schema.inputs[2].default == "lab" - assert schema.outputs[0].get_io_type() == "IMAGE" - - -def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): - decoded = torch.full((1, 3, 4, 4), 0.25) - reference = torch.full((1, 3, 4, 4), 0.75) - - def _lab(content, style): - raise torch.cuda.OutOfMemoryError("CUDA out of memory") - - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) - - with patch.object(nodes_seedvr, "lab_color_transfer", _lab): - try: - nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( - decoded, reference, torch.device("cpu"), "lab", - ) - except RuntimeError as exc: - assert "color_correction_method=lab" in str(exc) - assert " method=lab" not in str(exc) - else: - raise AssertionError("expected RuntimeError for one-frame LAB OOM") - - -def test_seedvr2_post_processing_unknown_color_correction_method_raises(): - decoded = torch.zeros(1, 2, 4, 4, 3) - original = torch.zeros(1, 2, 4, 4, 3) - try: - nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") - except ValueError as exc: - assert "color_correction_method" in str(exc) - else: - raise AssertionError("expected ValueError for unknown color_correction_method") diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index c63f69a0d..4e9350602 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -73,24 +73,6 @@ def _make_flux_schnell_comfyui_sd(): return sd -def _make_seedvr2_7b_separate_mm_sd(): - return { - "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), - } - - -def _make_seedvr2_7b_shared_mm_sd(): - return { - "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), - } - - -def _make_seedvr2_3b_shared_mm_sd(): - return { - "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), - } - - class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -143,48 +125,6 @@ class TestModelDetection: assert model_config is not None assert type(model_config).__name__ == "FluxSchnell" - def test_seedvr2_7b_separate_mm_detection_config(self): - sd = _make_seedvr2_7b_separate_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 3072 - assert unet_config["heads"] == 24 - assert unet_config["num_layers"] == 36 - assert unet_config["mm_layers"] == 36 - assert unet_config["mlp_type"] == "normal" - assert unet_config["qk_rope"] is True - assert unet_config["rope_type"] == "rope3d" - assert unet_config["rope_dim"] == 64 - - def test_seedvr2_7b_shared_mm_detection_config(self): - sd = _make_seedvr2_7b_shared_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 3072 - assert unet_config["heads"] == 24 - assert unet_config["num_layers"] == 36 - assert unet_config["mm_layers"] == 10 - assert unet_config["mlp_type"] == "swiglu" - assert unet_config["qk_rope"] is True - assert unet_config["rope_type"] == "rope3d" - assert unet_config["rope_dim"] == 64 - - def test_seedvr2_3b_shared_mm_detection_config(self): - sd = _make_seedvr2_3b_shared_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 2560 - assert unet_config["heads"] == 20 - assert unet_config["num_layers"] == 32 - assert unet_config["mlp_type"] == "swiglu" - assert unet_config["qk_rope"] is None - def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py deleted file mode 100644 index f9dbd6890..000000000 --- a/tests-unit/comfy_test/seedvr_vae_forward_test.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must -honor the actual tensor/tuple return contract of ``encode()`` and -``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` -or ``.sample`` attributes on those returns. - -The pre-fix body raised ``AttributeError: 'Tensor' object has no -attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and -``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` -for ``mode == "decode"`` (the class only defines ``decode_`` with a -trailing underscore). The post-fix body unwraps the optional one-element -tuple shape that ``return_dict=False`` produces and returns the tensor -directly. - -Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses -the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and -overrides ``encode``/``decode_`` with known tensors so the contract can -be probed without loading any real VAE weights. -""" - -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 - - -_LATENT_SHAPE = (1, 16, 2, 2, 2) -_DECODED_SHAPE = (1, 3, 5, 16, 16) -_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) -_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) - - -class _StubVAE(VideoAutoencoderKL): - def __init__(self): - nn.Module.__init__(self) - self._encode_out = torch.zeros(*_LATENT_SHAPE) - self._decode_out = torch.zeros(*_DECODED_SHAPE) - - def encode(self, x, return_dict=True): - return self._encode_out - - def decode_(self, z, return_dict=True): - return self._decode_out - - -def test_forward_encode_returns_tensor(): - vae = _StubVAE() - x = torch.zeros(*_INPUT_ENCODE_SHAPE) - result = vae.forward(x, mode="encode") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_LATENT_SHAPE) - - -def test_forward_decode_returns_tensor(): - vae = _StubVAE() - z = torch.zeros(*_INPUT_DECODE_SHAPE) - result = vae.forward(z, mode="decode") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_DECODED_SHAPE) - - -class _TupleReturningStubVAE(VideoAutoencoderKL): - """Stub variant whose ``encode``/``decode_`` return the - ``(tensor,)`` one-element tuple shape ``return_dict=False`` produces - in the parent class. Exercises the unwrap branch of - ``VideoAutoencoderKL.forward``. - """ - - def __init__(self): - nn.Module.__init__(self) - self._encode_tensor = torch.zeros(*_LATENT_SHAPE) - self._decode_tensor = torch.zeros(*_DECODED_SHAPE) - - def encode(self, x, return_dict=True): - return (self._encode_tensor,) - - def decode_(self, z, return_dict=True): - return (self._decode_tensor,) - - -def test_forward_all_unwraps_one_tuple_at_each_step(): - vae = _TupleReturningStubVAE() - x = torch.zeros(*_INPUT_ENCODE_SHAPE) - result = vae.forward(x, mode="all") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_DECODED_SHAPE) diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py deleted file mode 100644 index e5d79a306..000000000 --- a/tests-unit/comfy_test/test_seedvr2_dtype.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sd -import comfy.supported_models -import comfy.ldm.seedvr.model as seedvr_model - - -def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch): - bf16_device = object() - fp16_device = object() - - monkeypatch.setattr( - comfy.supported_models.comfy.model_management, - "should_use_bf16", - lambda device=None: device is bf16_device, - ) - - bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) - bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device) - assert bf16_config.manual_cast_dtype is torch.bfloat16 - - fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) - fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device) - assert fp16_config.manual_cast_dtype is None - - -def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): - context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2) - - txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) - - torch.testing.assert_close(txt, context.squeeze(0)) - torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device)) - - -def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): - estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) - old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 - - assert estimate == 101 * 960 * 1280 * 160 - assert estimate > 15 * 1024 ** 3 - assert estimate > old_estimate * 100 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py deleted file mode 100644 index 5b008ea6e..000000000 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ /dev/null @@ -1,341 +0,0 @@ -"""Consolidated SeedVR2 internals regression tests. - -Sources (all merged verbatim, helper names disambiguated where colliding): - - * RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy - apply_rotary_emb wrapper oracle at fp32. - * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare - memory_occupy against get_norm_limit(), not float('inf'). - * SeedVR2 variable-length attention split-loop contract. - -Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and -comfy.ldm.modules.attention transitively pull in comfy.model_management, -which probes torch.cuda.current_device() at import time unless args.cpu is -set first. -""" - -from __future__ import annotations - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args - -if not torch.cuda.is_available(): - args.cpu = True - -import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -import comfy.ldm.modules.attention as attention # noqa: E402 -import comfy.ops as comfy_ops # noqa: E402 -from comfy.ldm.seedvr.model import ( # noqa: E402 - Cache, - NaMMRotaryEmbedding3d, -) -from comfy.ldm.seedvr.vae import ( # noqa: E402 - causal_norm_wrapper, - set_norm_limit, -) -from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402 - - -# --------------------------------------------------------------------------- -# RoPE rewrite tests (test_seedvr_rope_rewrite.py) -# --------------------------------------------------------------------------- - -# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains -# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8. -_DIM = 192 -_HEADS = 4 -_VID_T, _VID_H, _VID_W = 2, 4, 4 -_TXT_L = 8 -_L_VID = _VID_T * _VID_H * _VID_W -_SEED = 0 - - -def _make_inputs(dtype=torch.float32, device="cpu"): - """Construct the 6 forward inputs + cache. Deterministic via local - Generator so global RNG state is not mutated. - """ - g = torch.Generator(device=device).manual_seed(_SEED) - vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device) - txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device) - cache = Cache(disable=True) - return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache - - -def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape): - """Reproduce the pre-rewrite ``get_freqs`` body verbatim against - ``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method, - unchanged by the rewrite). - """ - max_temporal = 0 - max_height = 0 - max_width = 0 - max_txt_len = 0 - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) - max_height = max(max_height, h) - max_width = max(max_width, w) - max_txt_len = max(max_txt_len, l) - with torch.amp.autocast(device_type="cuda", enabled=False): - vid_freqs_full = rope.get_axial_freqs( - min(max_temporal + 16, 1024), - min(max_height + 4, 128), - min(max_width + 4, 128), - ).float() - txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024)) - vid_freq_list, txt_freq_list = [], [] - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1)) - txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1)) - vid_freq_list.append(vid_freq) - txt_freq_list.append(txt_freq) - return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) - - -def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape, - txt_q, txt_k, txt_shape): - """Compute expected forward output via the unchanged - ``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the - oracle. The wrapper itself is out of scope for the rewrite (Shape B). - """ - vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape) - vid_freqs = vid_freqs.to(vid_q.device) - txt_freqs = txt_freqs.to(txt_q.device) - - from einops import rearrange - - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") - vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) - vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) - vid_q_out = rearrange(vid_q_out, "h L d -> L h d") - vid_k_out = rearrange(vid_k_out, "h L d -> L h d") - - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") - txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) - txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) - txt_q_out = rearrange(txt_q_out, "h L d -> L h d") - txt_k_out = rearrange(txt_k_out, "h L d -> L h d") - return vid_q_out, vid_k_out, txt_q_out, txt_k_out - - -def test_namm_forward_output_tensor_equal_against_legacy_oracle(): - rope = NaMMRotaryEmbedding3d(dim=_DIM) - vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs() - - expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward( - rope, - vid_q.clone(), vid_k.clone(), vid_shape, - txt_q.clone(), txt_k.clone(), txt_shape, - ) - - actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward( - vid_q.clone(), vid_k.clone(), vid_shape, - txt_q.clone(), txt_k.clone(), txt_shape, cache, - ) - - torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0, - msg="vid_q output diverges from wrapper oracle") - torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0, - msg="vid_k output diverges from wrapper oracle") - torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0, - msg="txt_q output diverges from wrapper oracle") - torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0, - msg="txt_k output diverges from wrapper oracle") - - -# --------------------------------------------------------------------------- -# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) -# --------------------------------------------------------------------------- - -_NUM_CHANNELS = 8 -_NUM_GROUPS = 4 -_TENSOR_SHAPE = (1, 8, 2, 4, 4) - -_GROUPNORM_SUBCLASSES = [ - pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"), - pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"), -] - - -@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) -def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): - real_group_norm = vae_mod.F.group_norm - set_norm_limit(1e-9) - try: - gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS) - gn.eval() - - forward_hook_calls = [] - - def _hook(module, inputs, output): - forward_hook_calls.append(tuple(inputs[0].shape)) - - spy_calls = [] - - def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): - spy_calls.append({"num_groups": int(num_groups_arg)}) - return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) - - handle = gn.register_forward_hook(_hook) - try: - with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): - out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) - finally: - handle.remove() - - full_calls = len(forward_hook_calls) - chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS) - - assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE - assert full_calls == 0, ( - f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}" - ) - assert chunked_calls > 0, ( - f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}" - ) - finally: - set_norm_limit(None) - - -# --------------------------------------------------------------------------- -# SeedVR2 var_attention split-loop tests -# --------------------------------------------------------------------------- - -def test_var_attention_registry_contains_always_available_entries(): - assert ( - attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"] - is attention.var_attention_optimized_split - ) - - -def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): - dim = 8 - heads = 2 - head_dim = 4 - attn = seedvr_model.NaSwinAttention( - vid_dim=dim, - txt_dim=dim, - heads=heads, - head_dim=head_dim, - qk_bias=False, - qk_norm=seedvr_model.CustomRMSNorm, - qk_norm_eps=1e-6, - rope_type=None, - rope_dim=head_dim, - shared_weights=False, - window=(2, 1, 1), - window_method="720pwin_by_size_bysize", - version=True, - device="cpu", - dtype=torch.float32, - operations=comfy_ops.disable_weight_init, - ) - generator = torch.Generator(device="cpu").manual_seed(11) - vid = torch.randn(8, dim, generator=generator) - txt = torch.randn(3, dim, generator=generator) - vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long) - txt_shape = torch.tensor([[3]], dtype=torch.long) - calls = [] - - def fake_optimized_var_attention(**kwargs): - calls.append(kwargs) - return kwargs["q"] - - monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention) - - vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True)) - - assert tuple(vid_out.shape) == (8, dim) - assert tuple(txt_out.shape) == (3, dim) - assert len(calls) == 1 - call = calls[0] - assert tuple(call["q"].shape) == (14, heads, head_dim) - assert tuple(call["k"].shape) == (14, heads, head_dim) - assert tuple(call["v"].shape) == (14, heads, head_dim) - assert call["heads"] == heads - assert call["skip_reshape"] is True - assert call["skip_output_reshape"] is True - torch.testing.assert_close( - call["cu_seqlens_q"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - torch.testing.assert_close( - call["cu_seqlens_k"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - - -def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): - heads = 2 - head_dim = 3 - q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) - k = q + 100 - v = q + 200 - cu = torch.tensor([0, 2, 5], dtype=torch.int32) - calls = [] - - def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): - calls.append( - { - "q_shape": tuple(q_arg.shape), - "k_shape": tuple(k_arg.shape), - "v_shape": tuple(v_arg.shape), - "heads": heads_arg, - "kwargs": kwargs, - } - ) - return q_arg + v_arg - - monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention) - - out = var_attention_optimized_split( - q, - k, - v, - heads, - cu, - cu, - skip_reshape=True, - skip_output_reshape=True, - ) - - assert tuple(out.shape) == (5, heads, head_dim) - assert len(calls) == 2 - assert calls[0]["q_shape"] == (1, heads, 2, head_dim) - assert calls[1]["q_shape"] == (1, heads, 3, head_dim) - assert all(call["heads"] == heads for call in calls) - assert all(call["kwargs"]["skip_reshape"] is True for call in calls) - assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) - torch.testing.assert_close(out, q + v, rtol=0, atol=0) - - -def test_var_attention_optimized_split_rejects_bad_offsets(): - q = torch.randn(5, 2, 3) - cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) - cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) - - with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): - var_attention_optimized_split( - q, - q, - q, - 2, - cu_bad, - cu_ok, - skip_reshape=True, - skip_output_reshape=True, - ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py deleted file mode 100644 index f2b9bcbbe..000000000 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Consolidated SeedVR2 model/graph/forward regression tests. - -Merged from: -- seedvr_model_test.py -- test_seedvr_7b_final_block_text_path.py -- test_seedvr_forward_no_device_cast.py -- test_seedvr_latent_format.py -- test_seedvr2_vae_graph_boundaries.py -""" - -from __future__ import annotations - -from unittest.mock import MagicMock - -import torch -from torch import nn - -from comfy.cli_args import args - -if not torch.cuda.is_available(): - args.cpu = True - -import comfy # noqa: E402 -import comfy.latent_formats # noqa: E402 -import comfy.ldm.seedvr.model # noqa: E402 -import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 -import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 -import comfy.model_management # noqa: E402 -import comfy.sample # noqa: E402 -import comfy.sd as sd_mod # noqa: E402 -import nodes as nodes_mod # noqa: E402 -from comfy.ldm.seedvr.model import NaDiT # noqa: E402 - - -# --------------------------------------------------------------------------- -# Helpers from seedvr_model_test.py -# --------------------------------------------------------------------------- - - -def _make_standin(positive_conditioning): - class _StandIn(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer( - "positive_conditioning", positive_conditioning - ) - - _resolve_text_conditioning = NaDiT._resolve_text_conditioning - - return _StandIn() - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - -class _StubModule(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - -def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]: - flags = [] - - class _Block(_StubModule): - def __init__(self, *args, **kwargs): - flags.append(kwargs["is_last_layer"]) - super().__init__() - - monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule) - monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule) - monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule) - monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block) - - seedvr_model.NaDiT( - norm_eps=1e-5, - qk_rope=None, - num_layers=4, - mlp_type="normal", - vid_dim=vid_dim, - txt_in_dim=txt_in_dim, - heads=24, - mm_layers=3, - ) - - return flags - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - -class _Model: - def __init__(self, latent_format): - self._latent_format = latent_format - - def get_model_object(self, name): - assert name == "latent_format" - return self._latent_format - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - -class _Patcher: - def get_free_memory(self, device): - return 1024 * 1024 * 1024 - - -class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): - def __init__(self, encoded): - nn.Module.__init__(self) - self.encoded = encoded - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.seen = [] - - def encode(self, x): - self.seen.append(tuple(x.shape)) - return self.encoded.to(device=x.device, dtype=x.dtype) - - -class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): - def __init__(self): - nn.Module.__init__(self) - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.calls = [] - - def decode(self, z, seedvr2_tiling=None): - self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) - if z.ndim == 4: - b, tc, h, w = z.shape - t = tc // 16 - else: - b, _, t, h, w = z.shape - return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) - - -def _make_vae(wrapper): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = wrapper - vae.device = torch.device("cpu") - vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.latent_channels = 16 - vae.latent_dim = 3 - vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) - vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - vae.output_channels = 3 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.crop_input = False - vae.not_video = False - vae.patcher = _Patcher() - vae.process_input = lambda image: image - vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) - vae.vae_output_dtype = lambda: torch.float32 - vae.memory_used_encode = lambda shape, dtype: 1 - vae.memory_used_decode = lambda shape, dtype: 1 - vae.throw_exception_if_invalid = lambda: None - vae.vae_encode_crop_pixels = lambda pixels: pixels - vae.spacial_compression_decode = lambda: 8 - vae.temporal_compression_decode = lambda: 4 - return vae - - -# --------------------------------------------------------------------------- -# Tests from seedvr_model_test.py -# --------------------------------------------------------------------------- - - -def test_missing_context_falls_back_to_positive_buffer(): - """AC: ``context is None`` falls back to the registered - ``positive_conditioning`` buffer and runs to completion — no - silent zero substitution, no raised exception. - """ - pos_buffer = torch.full((58, 5120), 7.0) - standin = _make_standin(pos_buffer) - txt, txt_shape = standin._resolve_text_conditioning(None) - assert txt.shape == (58, 5120) - assert (txt == 7.0).all(), ( - "fallback path must use the positive_conditioning buffer " - "verbatim, not a zero tensor" - ) - assert txt_shape.shape == (1, 1) - assert txt_shape[0, 0].item() == 58 - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): - assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ - False, - False, - False, - False, - ] - - -def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): - rope = seedvr_model.get_na_rope("rope3d", dim=64) - generator = torch.Generator(device="cpu").manual_seed(0) - q = torch.randn(4, 2, 128, generator=generator) - k = torch.randn(4, 2, 128, generator=generator) - shape = torch.tensor([[1, 2, 2]], dtype=torch.long) - freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) - - expected_q = seedvr_model._apply_seedvr2_rotary_emb( - freqs, - q.permute(1, 0, 2).float(), - ).to(q.dtype).permute(1, 0, 2) - expected_k = seedvr_model._apply_seedvr2_rotary_emb( - freqs, - k.permute(1, 0, 2).float(), - ).to(k.dtype).permute(1, 0, 2) - - actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True)) - - torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0) - torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): - latent_format = comfy.latent_formats.SeedVR2() - latent_image = torch.zeros(1, 1, 4, 5) - - fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) - - assert latent_format.latent_channels == 16 - assert latent_format.latent_dimensions == 2 - assert fixed.shape == (1, 16, 4, 5) - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): - monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - - encoded = torch.full((1, 16, 2, 4, 5), 2.0) - vae = _make_vae(_EncodeWrapper(encoded)) - pixels = torch.zeros(1, 5, 32, 40, 3) - - node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] - node_latent = node_output["samples"] - assert set(node_output) == {"samples"} - assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) - assert node_latent.dtype == torch.float32 - assert node_latent.stride()[-1] == 1 - assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) - - tiled = torch.full((1, 16, 2, 4, 5), 3.0) - monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) - tiled_output = nodes_mod.VAEEncodeTiled().encode( - vae, - pixels, - tile_size=512, - overlap=64, - temporal_size=16, - temporal_overlap=4, - )[0] - tiled_latent = tiled_output["samples"] - assert set(tiled_output) == {"samples"} - assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) - assert tiled_latent.dtype == torch.float32 - assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) - - -def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch): - monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - vae = _make_vae(_DecodeWrapper()) - - nodes_mod.VAEDecodeTiled().decode( - vae, - {"samples": torch.zeros(1, 16, 2, 4, 5)}, - tile_size=512, - overlap=64, - temporal_size=16, - temporal_overlap=4, - ) - - assert vae.first_stage_model.calls == [ - { - "shape": (1, 16, 2, 4, 5), - "seedvr2_tiling": { - "enable_tiling": True, - "tile_size": (512, 512), - "tile_overlap": (64, 64), - "temporal_size": 16, - "temporal_overlap": 4, - }, - } - ] diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py deleted file mode 100644 index ea9f978f3..000000000 --- a/tests-unit/comfy_test/test_seedvr2_vae_decode.py +++ /dev/null @@ -1,91 +0,0 @@ -from unittest.mock import patch - -import pytest -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -from comfy_extras import nodes_seedvr # noqa: E402 - - -def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: - wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( - vae_mod.VideoAutoencoderKLWrapper - ) - nn.Module.__init__(wrapper) - return wrapper - - -def _fingerprint_decode_(self, z, return_dict=True): - b = int(z.shape[0]) - t = int(z.shape[2]) - h = int(z.shape[3]) - w = int(z.shape[4]) - out = torch.empty(b, 3, t, h * 8, w * 8) - for batch_idx in range(b): - out[batch_idx].fill_(float(batch_idx + 1)) - return out - - -def _decode_with_patches(wrapper, z): - with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_): - return wrapper.decode(z) - - -def test_decode_b2_t3_multi_frame_batch_unchanged(): - wrapper = _make_wrapper() - - out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) - - assert tuple(out.shape) == (2, 3, 3, 16, 16) - - -class _Wrapper(vae_mod.VideoAutoencoderKLWrapper): - def __init__(self): - nn.Module.__init__(self) - self.calls = [] - - def parameters(self): - return iter([torch.nn.Parameter(torch.zeros(()))]) - -def _decode_stub(self, latent): - self.calls.append(tuple(latent.shape)) - return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8) - - -def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state(): - wrapper = _Wrapper() - - with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): - out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) - - assert tuple(out.shape) == (1, 3, 2, 32, 40) - assert wrapper.calls == [(1, 16, 2, 4, 5)] - - -def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): - wrapper = _Wrapper() - - with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): - wrapper.decode(torch.zeros(1, 16, 4)) - - -def _t_padded(t_in: int) -> int: - if t_in == 1: - return 1 - if t_in <= 4: - return 5 - if (t_in - 1) % 4 == 0: - return t_in - return t_in + (4 - ((t_in - 1) % 4)) - - -@pytest.mark.parametrize("t_in", [1, 5, 9]) -def test_t_padded_matches_cut_videos(t_in): - dummy = torch.zeros(1, t_in, 1, 1, 1) - assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py deleted file mode 100644 index 40079bbe2..000000000 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ /dev/null @@ -1,347 +0,0 @@ -from contextlib import ExitStack -from unittest.mock import MagicMock, patch - -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 -import comfy.sd as sd_mod # noqa: E402 -from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_decode_latent_min_size_override.py -# --------------------------------------------------------------------------- - - -def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): - from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae - - class StubVAEModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.slicing_latent_min_size = 2 - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self.use_slicing = True - self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) - self.decode_min_sizes = [] - self.memory_states = [] - - def decode_(self, t_chunk): - self.decode_min_sizes.append(self.slicing_latent_min_size) - return VideoAutoencoderKL.slicing_decode(self, t_chunk) - - def _decode(self, z, memory_state=MemoryState.DISABLED): - self.memory_states.append(memory_state) - b, c, d, h, w = z.shape - return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) - - vae = StubVAEModel() - z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) - - tiled_vae( - z, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=0, - temporal_overlap=0, - encode=False, - ) - - assert vae.decode_min_sizes == [5] - assert vae.memory_states == [MemoryState.DISABLED] - assert vae.slicing_latent_min_size == 2 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_encode_runt_slice_override.py -# --------------------------------------------------------------------------- - - -def test_zero_temporal_size_preserves_min_size_when_encode_raises(): - from comfy.ldm.seedvr.vae import tiled_vae - - class RaisingVAEModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.slicing_sample_min_size = 4 - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) - - def encode(self, t_chunk): - raise RuntimeError("simulated encode failure") - - vae = RaisingVAEModel() - x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) - - raised = False - try: - tiled_vae( - x, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=0, - temporal_overlap=0, - encode=True, - ) - except RuntimeError as exc: - if "simulated encode failure" not in str(exc): - raise - raised = True - - assert raised - assert vae.slicing_sample_min_size == 4 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_temporal_slicing.py -# --------------------------------------------------------------------------- - - -class _SlicingDecodeVAE(nn.Module): - def __init__(self, slicing_latent_min_size): - super().__init__() - self.slicing_latent_min_size = slicing_latent_min_size - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self.use_slicing = True - self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - self.decode_min_sizes = [] - self.memory_states = [] - - def decode_(self, z): - self.decode_min_sizes.append(self.slicing_latent_min_size) - return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) - - def _decode(self, z, memory_state=MemoryState.DISABLED): - self.memory_states.append(memory_state) - x = z[:, :1].repeat( - 1, - 3, - 1, - self.spatial_downsample_factor, - self.spatial_downsample_factor, - ) - return x - - -def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): - vae = _SlicingDecodeVAE(slicing_latent_min_size=2) - z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) - - tiled_vae( - z, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=12, - temporal_overlap=4, - encode=False, - ) - - assert vae.decode_min_sizes == [2] - assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] - assert vae.slicing_latent_min_size == 2 - - wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( - vae_mod.VideoAutoencoderKLWrapper - ) - nn.Module.__init__(wrapper) - seedvr2_tiling = { - "enable_tiling": True, - "tile_size": (64, 64), - "tile_overlap": (0, 0), - "temporal_size": 8, - "temporal_overlap": 7, - } - - captured = {} - - def _fake_tiled_vae(latent, model, **kwargs): - captured.update(kwargs) - return torch.zeros(1, 3, 1, 16, 16) - - with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): - wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) - - assert captured["temporal_overlap"] == 7 - - -# --------------------------------------------------------------------------- -# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py -# --------------------------------------------------------------------------- - - -def _force_oom(*a, **k): - raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") - - -def _make_vae(first_stage_model, latent_channels, latent_dim): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = first_stage_model - vae.patcher = MagicMock() - vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) - vae.device = vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.upscale_ratio = vae.downscale_ratio = 8 - vae.upscale_index_formula = vae.downscale_index_formula = None - vae.output_channels = 3 - vae.latent_channels = latent_channels - vae.latent_dim = latent_dim - vae.vae_output_dtype = lambda: torch.float32 - vae.spacial_compression_decode = lambda: 8 - vae.process_input = lambda x: x - vae.process_output = lambda x: x - vae.throw_exception_if_invalid = lambda: None - vae.memory_used_decode = lambda *a, **k: 1 - return vae - - -def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): - mm = sd_mod.model_management - with ExitStack() as stack: - stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None)) - stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None)) - stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None)) - stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)) - stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call)) - if patch_wrapper_decode: - stack.enter_context(patch.object( - seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", - side_effect=_force_oom)) - vae.decode(samples) - - -def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2(): - wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( - seedvr_vae_mod.VideoAutoencoderKLWrapper) - vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) - generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) - assert seedvr2_call.call_count == 1 - assert generic_call.call_count == 0 - - -def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): - first_stage = MagicMock() - first_stage.decode = MagicMock(side_effect=_force_oom) - vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) - generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False) - assert generic_call.call_count == 1 - assert seedvr2_call.call_count == 0 - - -# --------------------------------------------------------------------------- -# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py -# --------------------------------------------------------------------------- - - -def _populate_common_vae_attrs_fallback(vae): - vae.patcher = MagicMock() - vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) - vae.device = torch.device("cpu") - vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.upscale_ratio = 8 - vae.upscale_index_formula = None - vae.output_channels = 3 - vae.latent_channels = 16 - vae.latent_dim = 3 - vae.downscale_ratio = 8 - vae.downscale_index_formula = None - vae.not_video = False - vae.crop_input = False - vae.pad_channel_value = None - - vae.vae_output_dtype = lambda: torch.float32 - vae.spacial_compression_encode = lambda: 8 - vae.process_input = lambda x: x - vae.process_output = lambda x: x - vae.throw_exception_if_invalid = lambda: None - vae.memory_used_encode = lambda *a, **k: 1 - - -def _make_seedvr2_vae_fallback(): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( - seedvr_vae_mod.VideoAutoencoderKLWrapper - ) - vae.first_stage_model = wrapper - _populate_common_vae_attrs_fallback(vae) - return vae - - -def _make_non_seedvr2_vae_fallback(): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = MagicMock() - _populate_common_vae_attrs_fallback(vae) - return vae - - -def _force_regular_encode_oom(*args, **kwargs): - raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") - - -def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom(): - vae = _make_seedvr2_vae_fallback() - pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - - seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - - with patch.object(sd_mod.model_management, "raise_non_oom", - lambda e: None), \ - patch.object(sd_mod.model_management, "load_models_gpu", - lambda *a, **k: None), \ - patch.object(sd_mod.model_management, "soft_empty_cache", - lambda: None), \ - patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", - side_effect=_force_regular_encode_oom), \ - patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, - create=True), \ - patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): - vae.encode(pixel_samples) - - assert seedvr2_call.call_count == 1, ( - f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D " - f"input under OOM fallback; got {seedvr2_call.call_count} calls." - ) - assert generic_call.call_count == 0, ( - f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " - f"{generic_call.call_count} calls." - ) - - -def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): - vae = _make_non_seedvr2_vae_fallback() - vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) - vae.upscale_ratio = (lambda a: a * 4, 8, 8) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - - with patch.object(sd_mod.model_management, "load_models_gpu", - lambda *a, **k: None), \ - patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): - vae.encode_tiled(pixel_samples) - - assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64) diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py deleted file mode 100644 index 05291989e..000000000 --- a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sample # noqa: E402 -import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 -from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 - -_LAT_C = 16 -_COND_C = 17 - - -def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): - """Build minimal SeedVR2-shaped sampling inputs.""" - samples_5d = torch.arange( - B * _LAT_C * T * H * W, dtype=torch.float32 - ).reshape(B, _LAT_C, T, H, W) - samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() - - cond_5d = torch.arange( - B * _COND_C * T * H * W, dtype=torch.float32 - ).reshape(B, _COND_C, T, H, W) + 10000.0 - cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() - - text_pos = torch.zeros(1, 4, 32) - text_neg = torch.zeros(1, 4, 32) - positive = [[text_pos, {"condition": cond.clone()}]] - negative = [[text_neg, {"condition": cond.clone()}]] - latent_image = {"samples": samples} - return latent_image, positive, negative, samples_5d, cond_5d - - -def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): - return latent_image - - -def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): - """Return a tensor whose values encode ``(seed, position)``.""" - base = torch.arange( - latent_image.numel(), dtype=torch.float32 - ).reshape(latent_image.shape) - return base + float(seed) * 1e6 - - -def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): - schema = SeedVR2ProgressiveSampler.define_schema() - inputs = {item.id: item for item in schema.inputs} - - assert inputs["chunking_mode"].options == ["manual", "auto"] - assert inputs["chunking_mode"].default == "manual" - - -def test_auto_chunking_walks_two_three_four_chunk_ladder(): - """Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM.""" - latent, pos, neg, _, _ = _make_inputs(T=17) - calls = [] - - def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name, - scheduler, positive, negative, - latent_image, denoise=1.0, - noise_mask=None, seed=None): - calls.append(tuple(latent_image.shape)) - if latent_image.shape[1] > _LAT_C * 5: - raise torch.cuda.OutOfMemoryError("chunk too large") - return latent_image.clone() - - with patch.object(comfy.sample, "sample", - side_effect=_oom_until_four_chunks), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise), \ - patch.object(nodes_seedvr_mod.comfy.model_management, - "soft_empty_cache") as soft_empty: - out = SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=65, temporal_overlap=0, - chunking_mode="auto", - ) - - assert calls[:4] == [ - (1, _LAT_C * 17, 8, 8), - (1, _LAT_C * 9, 8, 8), - (1, _LAT_C * 6, 8, 8), - (1, _LAT_C * 5, 8, 8), - ] - assert torch.equal(out.result[0]["samples"], latent["samples"]) - assert soft_empty.call_count == 3 - - -@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) -def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): - """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" - latent, pos, neg, _, _ = _make_inputs(T=5) - - sampler_called = {"n": 0} - - def _should_not_be_called(*args, **kwargs): - sampler_called["n"] += 1 - return torch.zeros(1) - - with patch.object(comfy.sample, "sample", - side_effect=_should_not_be_called), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise): - with pytest.raises(ValueError) as excinfo: - SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, - ) - assert str(bad_chunk) in str(excinfo.value) - assert sampler_called["n"] == 0 From cb9f6394160808f7d25163f6cc2ea300c6841ef9 Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Tue, 9 Jun 2026 12:19:13 +0900 Subject: [PATCH 3/8] chore(openapi): sync shared API contract from cloud@5273c30 (#14266) --- openapi.yaml | 229 +++++++++++++++++---------------------------------- 1 file changed, 76 insertions(+), 153 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index b7e21245f..2510f97d0 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3,11 +3,6 @@ components: Asset: description: Represents a user-owned asset (image, video, or other generated output). properties: - asset_hash: - deprecated: true - description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' - pattern: ^blake3:[a-f0-9]{64}$ - type: string created_at: description: Timestamp when the asset was created format: date-time @@ -16,8 +11,12 @@ components: description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true type: string + file_path: + description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") + nullable: true + type: string hash: - description: Blake3 hash of the asset content. Preferred over asset_hash. + description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ type: string id: @@ -139,17 +138,16 @@ components: AssetUpdated: description: Response returned when an existing asset is successfully updated. properties: - asset_hash: - deprecated: true - description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' - pattern: ^blake3:[a-f0-9]{64}$ - type: string display_name: description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true type: string + file_path: + description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") + nullable: true + type: string hash: - description: Blake3 hash of the asset content. Preferred over asset_hash. + description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ type: string id: @@ -828,7 +826,11 @@ components: type: string type: object PaginationInfo: - description: Offset/limit-based pagination metadata included in list responses. + description: | + Pagination metadata included in list responses. Supports both legacy + offset/limit pagination and cursor-based pagination. When cursor-based + pagination is used, `next_cursor` is the primary pagination token and + `offset`/`total` may be zero. properties: has_more: description: Whether more items are available beyond this page @@ -837,12 +839,19 @@ components: description: Items per page minimum: 1 type: integer + next_cursor: + description: | + Opaque cursor for the next page. Pass this value as the `after` + query parameter on the next request. Empty or absent when there + are no more results. + type: string offset: - description: Current offset (0-based) + deprecated: true + description: 'Current offset (0-based). Deprecated: use cursor-based pagination.' minimum: 0 type: integer total: - description: Total number of items matching filters + description: Total number of items matching filters (may be 0 when using cursor pagination) minimum: 0 type: integer required: @@ -1518,17 +1527,11 @@ paths: schema: default: true type: boolean - - description: Filter assets by exact content hash. Preferred over asset_hash. + - description: Filter assets by exact content hash. in: query name: hash schema: type: string - - deprecated: true - description: 'Deprecated: use hash instead. Filter assets by exact content hash.' - in: query - name: asset_hash - schema: - type: string - description: | Opaque cursor for keyset pagination. Pass the `next_cursor` value from the previous response to fetch the next page. When provided, @@ -1571,42 +1574,12 @@ paths: - file post: description: | - Uploads a new asset to the system with associated metadata. - Supports two upload methods: - 1. Direct file upload (multipart/form-data) - 2. URL-based upload (application/json with source: "url") + Creates a new asset from a direct file upload (multipart/form-data) with associated metadata. If an asset with the same hash already exists, returns the existing asset. - operationId: uploadAsset + operationId: createAsset requestBody: content: - application/json: - schema: - properties: - name: - description: Display name for the asset (used to determine file extension) - type: string - preview_id: - description: Optional preview asset ID - format: uuid - type: string - tags: - description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. - items: - type: string - type: array - url: - description: HTTP/HTTPS URL to download the asset from - format: uri - type: string - user_metadata: - additionalProperties: true - description: Custom metadata to store with the asset - type: object - required: - - url - - name - type: object multipart/form-data: schema: properties: @@ -1614,6 +1587,10 @@ paths: description: The asset file to upload format: binary type: string + hash: + description: Content hash of the file. + pattern: ^(blake3|sha256):[a-f0-9]{64}$ + type: string id: description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset. format: uuid @@ -1629,10 +1606,8 @@ paths: format: uuid type: string tags: - description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. - items: - type: string - type: array + description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. + type: string user_metadata: description: Custom JSON metadata as a string type: string @@ -1641,36 +1616,32 @@ paths: type: object required: true responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: | + Asset already existed for this user (deduplicated by content hash); the + existing asset is returned with created_new=false. "201": content: application/json: schema: $ref: '#/components/schemas/AssetCreated' - description: Asset created successfully + description: Asset created successfully (created_new=true) "400": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' - description: Invalid request (bad file, invalid URL, invalid content type, etc.) + description: Invalid request (bad file, invalid content type, etc.) "401": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' description: Unauthorized - "403": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Source URL requires authentication or access denied - "404": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Source URL not found "413": content: application/json: @@ -1683,19 +1654,13 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' description: Unsupported media type - "422": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Download failed due to network error or timeout "500": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' description: Internal server error - summary: Upload a new asset + summary: Create a new asset tags: - file /api/assets/{id}: @@ -1730,7 +1695,7 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' - description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version) + description: 'Asset cannot be deleted because it is referenced by another resource, e.g. a workflow version (error code: ASSET_IN_USE)' "500": content: application/json: @@ -1783,7 +1748,7 @@ paths: description: | Updates an asset's metadata. At least one field must be provided. Only name, mime_type, preview_id, and user_metadata can be updated. - For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint. + For tag management, use POST (add) and DELETE (remove) /api/assets/{id}/tags. operationId: updateAsset parameters: - description: Asset ID @@ -1982,76 +1947,6 @@ paths: summary: Add tags to asset tags: - file - put: - description: Adds and removes tags from an asset in a single operation - operationId: updateAssetTags - parameters: - - description: Asset ID - in: path - name: id - required: true - schema: - format: uuid - type: string - requestBody: - content: - application/json: - schema: - description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items. - minProperties: 1 - properties: - add: - description: Tags to add to the asset. Can be empty if remove has items. - items: - type: string - type: array - remove: - description: Tags to remove from the asset. Can be empty if add has items. - items: - type: string - type: array - type: object - required: true - responses: - "200": - content: - application/json: - schema: - $ref: '#/components/schemas/TagsModificationResponse' - description: Tags updated successfully - "400": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Invalid request - "401": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Unauthorized - "404": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Asset not found - "422": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Reserved tag validation error - "500": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Internal server error - summary: Update asset tags - tags: - - file /api/assets/from-hash: post: description: | @@ -2090,12 +1985,20 @@ paths: type: object required: true responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: | + Asset reference already existed for this user (deduplicated by content + hash); the existing asset is returned with created_new=false. "201": content: application/json: schema: $ref: '#/components/schemas/AssetCreated' - description: Asset reference created successfully + description: Asset reference created successfully (created_new=true) "400": content: application/json: @@ -2887,7 +2790,21 @@ paths: - asc - desc type: string - - description: Pagination offset (0-based) + - description: | + Opaque cursor for keyset pagination. Pass the `next_cursor` value + from a previous response to fetch the next page. + Cursor pagination is supported only when `sort_by=create_time` + (default). If `sort_by=execution_time`, `after` is ignored and + offset/limit pagination is used. + Cursors are opaque base64url payloads — clients should treat them + as strings and not parse the contents. + example: eyJzIjoiY3JlYXRlX3RpbWUiLCJ2IjoiMTcxNjIwMDAwMDAwMDAwMCIsImlkIjoiYTFiMmMzZDQtZTVmNi03YTg5LWIwYzEtZDJlM2Y0YTViNmM3In0 + in: query + name: after + schema: + type: string + - deprecated: true + description: 'Pagination offset (0-based). Deprecated: prefer cursor-based pagination via `after`.' in: query name: offset schema: @@ -2909,6 +2826,12 @@ paths: schema: $ref: '#/components/schemas/JobsListResponse' description: Success - Jobs retrieved + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Bad request (e.g. malformed pagination cursor). "401": content: application/json: From f89999289abe06c638e15d1895e3c7805bd486b1 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Mon, 8 Jun 2026 20:55:49 -0700 Subject: [PATCH 4/8] fix: Add back apply_rotary_emb for Qwen Image (#14364) --- comfy/ldm/qwen_image/model.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 3462d8108..e49886dd9 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -51,6 +51,18 @@ class FeedForward(nn.Module): return hidden_states +# Addin this back because Nunchaku custom nodes rely on it, see comment here: +# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161 +# TODO: Eventually remove this once we natively support SVDQuants +def apply_rotary_emb(x, freqs_cis): + if x.shape[1] == 0: + return x + + t_ = x.reshape(*x.shape[:-1], -1, 1, 2) + t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] + return t_out.reshape(*x.shape) + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() From 8ed7f458d055b565d063343bf94dab99f10f649a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 9 Jun 2026 16:11:05 +0300 Subject: [PATCH 5/8] Allow custom templates with Ideogram4 TE (#14374) --- comfy/text_encoders/ideogram4.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/ideogram4.py b/comfy/text_encoders/ideogram4.py index 55e655d67..84243772d 100644 --- a/comfy/text_encoders/ideogram4.py +++ b/comfy/text_encoders/ideogram4.py @@ -32,7 +32,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer): self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): - if llama_template is None: + if text.startswith('<|im_start|>'): + llama_text = text + elif llama_template is None: llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) From 1639dc7a7041eaaf7ad96f8c7ea2894be01a7d28 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 9 Jun 2026 23:55:00 +1000 Subject: [PATCH 6/8] main/server: Add --debug-hang (#14371) Add an option to debug a hang with ctrl-C, dumping the backtraces to see where its stuck or slow. --- comfy/cli_args.py | 2 ++ main.py | 15 ++++++++++++++- server.py | 9 +++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a4cabcc65..cba0dfa34 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -166,6 +166,8 @@ class PerformanceFeature(enum.Enum): parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.") + parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") diff --git a/main.py b/main.py index 239a52013..7fcc8e97d 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ import utils.extra_config from utils.mime_types import init_mime_types import faulthandler import logging +import signal import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context @@ -37,7 +38,19 @@ if __name__ == "__main__": os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -faulthandler.enable(file=sys.stderr, all_threads=False) +faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang) +if __name__ == "__main__" and args.debug_hang: + dumping_traceback = False + + def dump_traceback_on_sigint(signum, frame): + global dumping_traceback + if dumping_traceback: + raise KeyboardInterrupt + dumping_traceback = True + faulthandler.dump_traceback(file=sys.stderr, all_threads=True) + raise KeyboardInterrupt + + signal.signal(signal.SIGINT, dump_traceback_on_sigint) import comfy_aimdo.control diff --git a/server.py b/server.py index 268441bd1..a85c1e591 100644 --- a/server.py +++ b/server.py @@ -1253,6 +1253,15 @@ class PromptServer(): if verbose: logging.info("Starting server\n") + if args.debug_hang: + logging.info( + f"{'-' * 80}\n" + "ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n" + "the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n" + "terminal to dump the python backtraces for debugging. Please attach the extra\n" + "debug info to your bug report.\n" + f"{'-' * 80}" + ) for addr in addresses: address = addr[0] port = addr[1] From 07c53f8f0fa6b014a46756eaa5a07fa9e411ccad Mon Sep 17 00:00:00 2001 From: kelseyee <971704395@qq.com> Date: Tue, 9 Jun 2026 21:57:58 +0800 Subject: [PATCH 7/8] Add LoRA key mapping for LTXV/LTXAV models (#14349) --- comfy/lora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 4e0ea29e0..2c8d0f0bf 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["transformer.{}".format(key_lora)] = k + if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + return key_map From 184009c2f60db7b2e7dc4a80c28f9bc6029408d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 9 Jun 2026 18:24:09 +0300 Subject: [PATCH 8/8] feat: Add model support for SCAIL-2 (#14373) * initial SCAIL2 support --- comfy/ldm/wan/model.py | 57 ++++++- comfy/model_base.py | 74 +++++++++ comfy/model_detection.py | 2 + comfy/supported_models.py | 12 ++ comfy_extras/nodes_scail.py | 321 ++++++++++++++++++++++++++++++++++++ comfy_extras/nodes_wan.py | 58 ------- nodes.py | 1 + 7 files changed, 462 insertions(+), 63 deletions(-) create mode 100644 comfy_extras/nodes_scail.py diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 70dfe7b16..9178b3344 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1631,13 +1631,15 @@ class SCAILWanModel(WanModel): self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) - def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs): if reference_latent is not None: x = torch.cat((reference_latent, x), dim=2) # embeddings x = self.patch_embedding(x.float()).to(x.dtype) + if ref_mask_latents is not None: # SCAIL-2 additive mask stream + x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) grid_sizes = x.shape[2:] transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) @@ -1645,6 +1647,8 @@ class SCAILWanModel(WanModel): scail_pose_seq_len = 0 if pose_latents is not None: scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + if sam_latents is not None: # SCAIL-2 additive mask stream + scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype) scail_x = scail_x.flatten(2).transpose(1, 2) scail_pose_seq_len = scail_x.shape[1] x = torch.cat([x, scail_x], dim=1) @@ -1695,7 +1699,36 @@ class SCAILWanModel(WanModel): return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + # ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode, + # which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset. + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): + if ref_mask_flag is not None and not bool(ref_mask_flag): + REF_ROPE_H = 120.0 + POSE_ROPE_W = 120.0 + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + main_t_patches = t - ref_t_patches + + parts = [] + if ref_t_patches > 0: + ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} + parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) + if main_t_patches > 0: + parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) + + if pose_latents is not None: + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + h_scale = h / H_pose + w_scale = w / W_pose + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf)) + + return torch.cat(parts, dim=1) + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) if pose_latents is None: @@ -1719,12 +1752,16 @@ class SCAILWanModel(WanModel): return torch.cat([main_freqs, pose_freqs], dim=1) - def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) if pose_latents is not None: pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + if ref_mask_latents is not None: # SCAIL-2 + ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size) + if sam_latents is not None: # SCAIL-2 + sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size) t_len = t if time_dim_concat is not None: @@ -1737,5 +1774,15 @@ class SCAILWanModel(WanModel): reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) t_len += reference_latent.shape[2] - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] + ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2 + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w] + + +class SCAIL2WanModel(SCAILWanModel): + """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" + + def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) diff --git a/comfy/model_base.py b/comfy/model_base.py index 042804771..d212a7c2a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1754,6 +1754,80 @@ class WAN21_SCAIL(WAN21): return out +class WAN21_SCAIL2(WAN21_SCAIL): + """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" + + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents") + self.memory_usage_shape_process = { + "pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]], + "sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]], + } + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if driving_mask_28ch is not None: + out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous()) + + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous()) + + ref_mask_flag = kwargs.get("ref_mask_flag", None) + if ref_mask_flag is not None: + out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag) + + return out + + def extra_conds_shapes(self, **kwargs): + out = super().extra_conds_shapes(**kwargs) + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if driving_mask_28ch is not None: + s = driving_mask_28ch.shape + out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]] + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + s = ref_mask_28ch.shape + out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]] + return out + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key in ("sam_latents", "pose_latents"): + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + + def concat_cond(self, **kwargs): + # The 4 extra channels are the history_mask (1 at clean-anchor frames). + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels != 4: + return super().concat_cond(**kwargs) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + return torch.zeros_like(noise)[:, :4] + + device = kwargs["device"] + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + return mask + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + # Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image. + return latent_image + + class WAN22_WanDancer(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74c838d13..290938bd6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail2" elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "scail" elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7cf9c133b..42325d71c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V): out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) return out + +class WAN21_SCAIL2(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail2", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device) + return out + class WAN22_WanDancer(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -2259,6 +2270,7 @@ models = [ WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, + WAN21_SCAIL2, WAN22_WanDancer, Hunyuan3Dv2mini, Hunyuan3Dv2, diff --git a/comfy_extras/nodes_scail.py b/comfy_extras/nodes_scail.py new file mode 100644 index 000000000..a740442de --- /dev/null +++ b/comfy_extras/nodes_scail.py @@ -0,0 +1,321 @@ +"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3 +preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes.""" + +from typing_extensions import override + +import torch +import torch.nn.functional as F + +import nodes +import node_helpers +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, io +from comfy.ldm.sam3.tracker import unpack_masks + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + + +# Model was trained on these exact colors; deviating degrades multi-identity quality. +DEFAULT_PALETTE = [ + (0.0, 0.0, 1.0), # Blue + (1.0, 0.0, 0.0), # Red + (0.0, 1.0, 0.0), # Green + (1.0, 0.0, 1.0), # Magenta + (0.0, 1.0, 1.0), # Cyan + (1.0, 1.0, 0.0), # Yellow +] + + +def _unpack(track_data): + packed = track_data["packed_masks"] + if packed is None or packed.shape[1] == 0: + return None + return unpack_masks(packed) + + +def _first_frame_cx_area(masks_bool): + first = masks_bool[0].float() + H, W = first.shape[-2], first.shape[-1] + n_pixels = H * W + grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) + area = first.sum(dim=(-1, -2)).clamp_(min=1) + cx = (first * grid_x).sum(dim=(-1, -2)) / area + return (cx / W).tolist(), (area / n_pixels).tolist() + + +def _subset_track_data(track_data, obj_indices): + out = dict(track_data) + packed = track_data["packed_masks"] + if packed is None or not obj_indices: + out["packed_masks"] = None + if "scores" in out: + out["scores"] = [] + return out + out["packed_masks"] = packed[:, obj_indices].contiguous() + scores = track_data.get("scores") + if scores is not None: + out["scores"] = [scores[i] for i in obj_indices if i < len(scores)] + return out + + +def _render_colored_masks(track_data, background="black"): + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + device = comfy.model_management.intermediate_device() + dtype = comfy.model_management.intermediate_dtype() + bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0) + if packed is None or packed.shape[1] == 0: + T = track_data.get("n_frames", 1) if packed is None else packed.shape[0] + out = torch.empty(T, H, W, 3, device=device, dtype=dtype) + out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2] + return out + T, N_obj = packed.shape[0], packed.shape[1] + colors = torch.tensor( + [DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)], + device=device, dtype=dtype, + ) + masks_full = unpack_masks(packed.to(device)).float() + Hm, Wm = masks_full.shape[-2], masks_full.shape[-1] + masks_full = F.interpolate( + masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" + ).view(T, N_obj, H, W) > 0.5 + any_mask = masks_full.any(dim=1) + obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) + color_overlay = colors[obj_idx_map] + bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) + return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) + + +def _extract_mask_to_28ch(rgb_video): + """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent + (1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c) + threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking.""" + T, H, W, _ = rgb_video.shape + _ON_THRESH = 225.0 / 255.0 + mask = rgb_video.movedim(-1, 1).float() + R = (mask[:, 0:1] > _ON_THRESH).float() + G = (mask[:, 1:2] > _ON_THRESH).float() + B = (mask[:, 2:3] > _ON_THRESH).float() + nR, nG, nB = 1 - R, 1 - G, 1 - B + binary_7ch = torch.cat([ + R * G * B, # white + R * nG * nB, # red + nR * G * nB, # green + nR * nG * B, # blue + R * G * nB, # yellow + R * nG * B, # magenta + nR * G * B, # cyan + ], dim=1) + H_lat, W_lat = H, W + for _ in range(3): + H_lat = (H_lat + 1) // 2 + W_lat = (W_lat + 1) // 2 + binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area') + T_latent = (T - 1) // 4 + 1 + padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0) + out = padded.view(T_latent, 28, H_lat, W_lat) + return out.unsqueeze(0) + + +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="model/conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."), + io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."), + io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."), + io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."), + io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), + io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), + io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, + video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None, + pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + noise_mask = None + + ref_mask_flag = not replacement_mode + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag}) + + prev_trimmed = None + if previous_frames is not None and previous_frames.shape[0] > 0: + prev_trimmed = previous_frames[-previous_frame_count:] + video_frame_offset -= prev_trimmed.shape[0] + video_frame_offset = max(0, video_frame_offset) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + # Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte + if replacement_mode and reference_image_mask is not None: + rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype) + reference_image = reference_image * is_char + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] + if pose_video_mask is not None: + if pose_video_mask.shape[0] <= video_frame_offset: + pose_video_mask = None + else: + pose_video_mask = pose_video_mask[video_frame_offset:] + + # Truncate pose+mask jointly to the shorter of the two, capped at length. + ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None] + if ts: + T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1 + if pose_video is not None: + pose_video = pose_video[:T_kept] + if pose_video_mask is not None: + pose_video_mask = pose_video_mask[:T_kept] + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + if pose_video_mask is not None: + mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw) + positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch}) + + if reference_image_mask is not None: + ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw) + zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype) + ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1) + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch}) + + if prev_trimmed is not None: + pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + prev_latent = vae.encode(pf[:, :, :, :3]) + prev_latent_frames = min(prev_latent.shape[2], latent.shape[2]) + latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype) + noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype) + noise_mask[:, :, :prev_latent_frames] = 0.0 + + out_latent = {"samples": latent} + if noise_mask is not None: + out_latent["noise_mask"] = noise_mask + return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length) + + +class SCAIL2ColoredMask(io.ComfyNode): + """Render SAM3 tracks for the driving pose video and (optionally) the reference + image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` + across both outputs guarantees identity K maps to the same color on both + sides, for multi-person workflow consistency. + reference_image_mask is always rendered black-bg (model convention) + pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SCAIL2ColoredMask", + display_name="Create SCAIL-2 Colored Mask", + category="conditioning/video_models/scail", + inputs=[ + SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."), + SAM3TrackData.Input("ref_track_data", optional=True, + tooltip="SAM3 track of the reference image."), + io.String.Input("object_indices", default="", + tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), + io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", + tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."), + io.Boolean.Input("replacement_mode", default=False, + tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."), + ], + outputs=[ + io.Image.Output("pose_video_mask"), + io.Image.Output("reference_image_mask"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None): + def _prep(td): + masks_bool = _unpack(td) + if sort_by != "none" and masks_bool is not None: + cx, area = _first_frame_cx_area(masks_bool) + if sort_by == "left_to_right": + order = sorted(range(len(cx)), key=lambda i: cx[i]) + else: # "area" + order = sorted(range(len(area)), key=lambda i: -area[i]) + td = _subset_track_data(td, order) + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + packed = td.get("packed_masks") + n_obj = packed.shape[1] if packed is not None else 0 + indices = [i for i in indices if 0 <= i < n_obj] + td = _subset_track_data(td, indices) + return td + + drv = _prep(driving_track_data) + mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black") + + if ref_track_data is not None: + ref = _prep(ref_track_data) + reference_image_mask = _render_colored_masks(ref, "black") + else: + H, W = drv["orig_size"] + reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + + return io.NodeOutput(mask_video, reference_image_mask) + + +class SCAILExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanSCAILToVideo, + SCAIL2ColoredMask, + ] + + +async def comfy_entrypoint() -> SCAILExtension: + return SCAILExtension() diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 67d3a8443..d73be8e00 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -1456,63 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) -class WanSCAILToVideo(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="WanSCAILToVideo", - category="model/conditioning/video_models", - inputs=[ - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), - io.Vae.Input("vae"), - io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), - io.Int.Input("batch_size", default=1, min=1, max=4096), - io.ClipVisionOutput.Input("clip_vision_output", optional=True), - io.Image.Input("reference_image", optional=True), - io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), - io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), - io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), - io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), - ], - outputs=[ - io.Conditioning.Output(display_name="positive"), - io.Conditioning.Output(display_name="negative"), - io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), - ], - is_experimental=True, - ) - - @classmethod - def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - - ref_latent = None - if reference_image is not None: - reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - ref_latent = vae.encode(reference_image[:, :, :, :3]) - - if ref_latent is not None: - positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) - negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) - - if clip_vision_output is not None: - positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) - negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - - if pose_video is not None: - pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) - pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength - positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - - out_latent = {} - out_latent["samples"] = latent - return io.NodeOutput(positive, negative, out_latent) - - class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1533,7 +1476,6 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, - WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/nodes.py b/nodes.py index 2f5a478b5..4bf768045 100644 --- a/nodes.py +++ b/nodes.py @@ -2450,6 +2450,7 @@ async def init_builtin_extra_nodes(): "nodes_rtdetr.py", "nodes_frame_interpolation.py", "nodes_sam3.py", + "nodes_scail.py", "nodes_void.py", "nodes_wandancer.py", "nodes_hidream_o1.py",