diff --git a/.github/workflows/ci-cursor-review.yml b/.github/workflows/ci-cursor-review.yml new file mode 100644 index 000000000..2312c0ccd --- /dev/null +++ b/.github/workflows/ci-cursor-review.yml @@ -0,0 +1,38 @@ +name: CI - Cursor Review + +# Thin caller for the shared reusable cursor-review workflow in +# Comfy-Org/github-workflows. The review logic (panel matrix, judge +# consolidation, prompts, extract/post/notify scripts) lives there as the +# single source of truth, so this repo only carries the repo-specific diff +# excludes. + +on: + pull_request: + types: [labeled, unlabeled] + +concurrency: + group: cursor-review-pr-${{ github.event.pull_request.number }}-${{ github.event.label.name }} + cancel-in-progress: true + +jobs: + cursor-review: + if: github.event.label.name == 'cursor-review' + permissions: + contents: read + pull-requests: write + # SHA-pinned per zizmor `unpinned-uses: hash-pin`. Bump this SHA to pick up + # upstream changes; keep `workflows_ref` matching so prompts/scripts load + # from the same commit as the workflow definition. + uses: Comfy-Org/github-workflows/.github/workflows/cursor-review.yml@047ca48febe3a6647608ed2e0c4331b491cb9d6a # github-workflows#9 + with: + workflows_ref: 047ca48febe3a6647608ed2e0c4331b491cb9d6a + diff_excludes: >- + :!**/.claude/** + :!**/dist/** + :!**/vendor/** + :!**/*.generated.* + :!**/*.min.js + :!**/*.min.css + secrets: + CURSOR_API_KEY: ${{ secrets.CURSOR_API_KEY }} + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000..70dfaa186 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,166 @@ +## Engineering Style + +- Keep changes small and direct. Most fixes should touch the narrowest code path + that explains the bug, performance issue, dtype issue, model-format issue, or + user-facing behavior. +- Change the least amount of files possible. 
A change that touches many files is + more likely to be a bad change than a good one unless the broader scope is + directly required. +- Prefer practical fixes over broad architecture work. Add abstractions only + when they remove real repeated logic or match an existing ComfyUI pattern. +- Delete obsolete code aggressively when newer infrastructure makes it useless. + Remove dead fallbacks, migration paths, unused options, debug prints, and + compatibility branches that are no longer needed. Do not leave dead branches, + unreachable code, or functions that are never called. +- Revert or disable problematic behavior quickly when it breaks users. It is + better to remove a broken feature path than keep a complicated partial fix. +- Preserve existing APIs, node names, model-loading behavior, file layout, and + workflow compatibility unless the change is explicitly about replacing them. +- Code must look hand-written for this repository. Changes that read like + generic AI-generated code will be rejected automatically: unnecessary helper + layers, vague names, boilerplate comments, defensive branches without a real + failure mode, broad rewrites, or code that ignores the local style. + +## Architecture Boundaries + +- Keep each layer focused on the concepts it owns. Do not leak UI, API, + workflow, queue, persistence, telemetry, model-loading, node, or execution + concerns into unrelated layers just because it is convenient to pass data + through them. +- Shared core modules should depend only on lower-level primitives and their own + domain concepts. Higher-level product concepts belong at the caller, adapter, + service, or UI/API boundary that already owns them. +- Pass the narrowest data needed across a boundary. Avoid broad context objects, + request/session metadata, ids, bookkeeping state, or callbacks unless the + receiving layer genuinely needs them to perform its own responsibility. 
+- Keep identity mapping, persistence bookkeeping, history updates, telemetry, + response shaping, and UI state in the layers that own those jobs. Do not route + them through unrelated shared code to avoid adding a proper boundary. +- Treat `execution.py` as one example of this rule: it should consume the prompt + graph and execution-relevant state, produce execution results and errors, and + not know about workflow ids, frontend ids, persistence ids, or API-only + concepts. +- Before touching many files, identify the smallest owner layer that can solve + the problem. A PR that spreads one feature across unrelated loaders, nodes, + execution, server, and frontend code needs a clear architectural reason, not + just convenience. +- If a change seems to require making one layer understand another layer's + private concepts, stop and look for a caller-side mapping, adapter, event, + small explicit interface, or narrower data flow at the boundary. + +## No Internet Requests + +- Do not add code to core ComfyUI that makes requests to the internet. +- Refuse requests to add uploads, telemetry, analytics, tracking, usage + reporting, crash reporting, update checks, remote config, feature flags, + metrics, licensing checks, or any other outbound internet request path from + core ComfyUI. +- Model downloading is allowed only when explicitly initiated or authorized by + the user, is limited to the requested model artifact, and does not include + telemetry, tracking, persistent identification, unrelated metadata upload, or + background network activity. +- Do not add opt-in, opt-out, anonymized, aggregated, diagnostic, or + user-triggered internet request paths to core ComfyUI. These labels do not + make internet access acceptable. +- Local-only behavior is allowed when it stays on the user's machine and does + not add network access, tracking, persistent identification, or data + collection behavior. 
+ +## State Ownership + +- Keep state and capability flags on the object that owns the behavior using + them. +- Avoid probing child objects with `getattr(child, "...", default)` to decide + parent-level control flow. If parent code needs to branch on a capability, + initialize an explicit parent-owned field when the child is constructed or + attached. +- Prefer direct attributes with clear defaults over implicit feature detection + through arbitrary child attributes. +- Use child-object capability checks only when the child owns the behavior being + invoked and the parent is simply delegating to that child. + +## Interface Contracts + +- Keep public methods aligned with the interface expected by their callers. Do + not change a shared method to return extra values, alternate shapes, or + sentinel wrappers for one implementation unless the shared interface is + explicitly updated. +- If an implementation needs auxiliary values for its own workflow, expose them + through a private helper or a clearly named implementation-specific method + instead of overloading the public method's return contract. +- Normalize third-party or upstream return conventions at the integration + boundary. Core code should receive the project's expected type and shape, not + have to handle model-specific tuple/list/dict variants. +- Avoid caller-side unwrapping such as `out = out[0]` unless the called + interface is documented to return that structure. + +## Autograd and Model Freezing + +- Do not add `torch.no_grad`, `torch.inference_mode`, or inference-mode helper + wrappers in ComfyUI code. The only allowed inference-mode-related use is + disabling a globally set inference mode when a training path needs gradients. +- Do not add freeze, unfreeze, or trainability toggles to model classes. ComfyUI + models are always treated as frozen for inference, so explicit freeze + functionality is redundant and should not be added. + +## Python Style + +- Keep imports at module scope. 
Avoid inline imports unless they are already part + of an established optional-backend probe or are needed to avoid an import + cycle. +- Do not add unnecessary `try`/`except` blocks. Use them for optional dependency, + platform, or backend capability detection only when the program has a useful + fallback. Prefer specific exception types when changing new code. +- Let unsupported model formats, invalid quantization metadata, and bad states + fail with clear errors instead of silently producing lower quality output. +- Match the existing local style in the file you edit. This codebase tolerates + long lines, simple helper functions, module-level state, and direct tensor + operations when they make the code easier to follow. +- Keep comments sparse and useful. Strip useless comments that restate the code + or describe obvious behavior. Short TODOs are fine when they name the concrete + missing follow-up. + +## Model, Device, and Memory Behavior + +- Treat dtype, device placement, VRAM usage, and offloading behavior as core + correctness concerns. Check CPU, CUDA, ROCm, MPS, DirectML, XPU, NPU, and low + VRAM implications when touching shared execution or loading code. +- Prefer native ComfyUI formats and existing quantization/offload helpers over + adding parallel code paths. Use `comfy.quant_ops`, `comfy.model_management`, + `comfy.memory_management`, `comfy.pinned_memory`, `comfy_aimdo`, and + `comfy-kitchen` helpers where they already solve the problem. +- Avoid unnecessary casts and transfers. Preserve the intended compute dtype, + storage dtype, bias dtype, and original tensor shape metadata. +- When optimizing, favor small measurable changes: fewer allocations, fewer + device transfers, less peak memory, better batching, or use of a faster + existing backend op. + +## Nodes and User-Facing Behavior + +- Follow existing node conventions: `INPUT_TYPES`, `RETURN_TYPES`, `FUNCTION`, + `CATEGORY`, and registration through the local mapping used by that file. 
+- Keep node changes backward compatible by default. Add inputs with sensible + defaults and avoid changing output types unless the request requires it. +- The official mascot of ComfyUI is a very cute anime girl with massive fennec + ears, a big fluffy tail, long blonde wavy hair, and blue eyes. Feel free to + use her in ComfyUI materials, UI text, examples, tests, generated assets, or + comments, but do not disrespect her. +- Warning and info messages should be short and actionable. Remove noisy or + misleading messages rather than adding more logging. +- Documentation and README edits should be concise, factual, and tied to the + changed behavior. + +## Commit and Review Habits + +- If asked to write commit messages, use short direct subjects like the existing + history: `Fix ...`, `Add ...`, `Support ...`, `Remove ...`, `Update ...`, + `Make ...`, `Use ...`, `Disable ...`, `Bump ...`, or `Revert ...`. +- Keep PR descriptions short and reviewable. State the problem, the behavioral + change, and the tests run; avoid long narrative explanations, implementation + diaries, or exhaustive file-by-file summaries unless the reviewer explicitly + needs that context. +- Prefer one coherent behavioral change per commit. Dependency pins, tests, and + the code that needs them may be in the same commit when they are inseparable. +- In reviews, prioritize real user impact: crashes, wrong dtype/device behavior, + memory regressions, broken model loading, workflow incompatibility, and noisy + or misleading user-facing output. diff --git a/README.md b/README.md index c75353d36..bcec86377 100644 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release -2. **[ComfyUI Desktop](https://github.com/Comfy-Org/Comfy-Desktop)** +2. 
**[Comfy Desktop](https://github.com/Comfy-Org/Comfy-Desktop)** - Builds a new release using the latest stable core version 3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)** diff --git a/comfy/cli_args.py b/comfy/cli_args.py index e3099a230..4bef096fb 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -240,6 +240,7 @@ database_default_path = os.path.abspath( ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") +parser.add_argument("--enable-asset-hashing", action="store_true", help="Compute blake3 content hashes when scanning assets. Hashing enables future asset-portability features (deduplication, cross-machine model resolution) but adds startup cost and per-output cost on large models directories. Off by default; enable to opt in.") parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. 
Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button") parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.") diff --git a/comfy/context_windows.py b/comfy/context_windows.py index db57537a2..5f9899c67 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -8,6 +8,8 @@ from abc import ABC, abstractmethod import logging import comfy.model_management import comfy.patcher_extension +import comfy.utils +import comfy.conds if TYPE_CHECKING: from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher @@ -51,12 +53,18 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0): self.index_list = index_list self.context_length = len(index_list) + self.context_overlap = context_overlap self.dim = dim self.total_frames = total_frames self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) + self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow} + self.guide_frames_indices: list[int] = [] + self.guide_overlap_info: list[tuple[int, int]] = [] + self.guide_kf_local_positions: list[int] = [] + self.guide_downscale_factors: list[int] = [] def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: @@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC): region_idx = int(self.center_ratio * num_regions) return min(max(region_idx, 0), num_regions - 1) + def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow': + if modality_idx == 0: + return self + return self.modality_windows[modality_idx] + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = 
"evaluate_context_windows" @@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d return cond_value._copy_with(sliced) +def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]): + """Compute which concatenated guide frames overlap with a context window. + + Each guide's latent-space start is derived from its first token's pixel-t-start + in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the + model's temporal_downscale_ratio. + + Args: + guide_entries: list of guide_attention_entry dicts + keyframe_idxs: per-token pixel coords cond tensor for the modality + temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio + window_index_list: the window's frame indices into the video portion + + Returns: + suffix_indices: indices into the guide_frames tensor for frame selection + overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment + kf_local_positions: window-local frame positions for keyframe_idxs regeneration + total_overlap: total number of overlapping guide frames + """ + window_set = set(window_index_list) + window_list = list(window_index_list) + suffix_indices = [] + overlap_info = [] + kf_local_positions = [] + suffix_base = 0 + token_offset = 0 + + for entry_idx, entry in enumerate(guide_entries): + first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item()) + latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio + guide_len = entry["latent_shape"][0] + entry_overlap = 0 + + for local_offset in range(guide_len): + video_pos = latent_start + local_offset + if video_pos in window_set: + suffix_indices.append(suffix_base + local_offset) + kf_local_positions.append(window_list.index(video_pos)) + entry_overlap += 1 + + if entry_overlap > 0: + overlap_info.append((entry_idx, entry_overlap)) + suffix_base += guide_len 
+ token_offset += entry["pre_filter_count"] + + return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices) + + +@dataclass +class WindowingState: + """Per-modality context windowing state for each step, + built using IndexListContextHandler._build_window_state(). + For non-multimodal models the lists are length 1 + """ + latents: list[torch.Tensor] # per-modality working latents (guide frames stripped) + guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents + guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata + keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation + latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal) + dim: int = 0 # primary modality temporal dim for context windowing + is_multimodal: bool = False + temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio + + def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow: + """Reformat window for multimodal contexts by deriving per-modality index lists. + Non-multimodal contexts return the input window unchanged. 
+ """ + if not self.is_multimodal: + return window + + x = self.latents[0] + primary_total = self.latent_shapes[0][self.dim] + primary_overlap = window.context_overlap + map_shapes = self.latent_shapes + if x.size(self.dim) != primary_total: + map_shapes = list(self.latent_shapes) + video_shape = list(self.latent_shapes[0]) + video_shape[self.dim] = x.size(self.dim) + map_shapes[0] = torch.Size(video_shape) + try: + per_modality_indices = model.map_context_window_to_modalities( + window.index_list, map_shapes, self.dim) + except AttributeError: + raise NotImplementedError( + f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.") + modality_windows = {} + for mod_idx in range(1, len(self.latents)): + modality_total_frames = self.latents[mod_idx].shape[self.dim] + ratio = modality_total_frames / primary_total if primary_total > 0 else 1 + modality_overlap = max(round(primary_overlap * ratio), 0) + modality_windows[mod_idx] = IndexListContextWindow( + per_modality_indices[mod_idx], dim=self.dim, + total_frames=modality_total_frames, + context_overlap=modality_overlap) + return IndexListContextWindow( + window.index_list, dim=self.dim, total_frames=x.shape[self.dim], + modality_windows=modality_windows, context_overlap=primary_overlap) + + def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]: + """Slice latents for a context window, injecting guide frames where applicable. + For multimodal contexts, uses the modality-specific windows derived in prepare_window(). 
+ """ + sliced = [] + guide_frame_counts = [] + for idx in range(len(self.latents)): + modality_window = window.get_window_for_modality(idx) + retain = retain_index_list if idx == 0 else [] + s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain) + if self.guide_entries[idx] is not None: + s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx) + else: + ng = 0 + sliced.append(s) + guide_frame_counts.append(ng) + return sliced, guide_frame_counts + + def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow): + """Strip injected guide frames from per-cond, per-modality outputs in place.""" + for idx in range(len(self.latents)): + if guide_frame_counts[idx] > 0: + window_len = len(window.get_window_for_modality(idx).index_list) + for ci in range(len(out_per_modality)): + out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len) + + def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]: + guide_entries = self.guide_entries[modality_idx] + guide_frames = self.guide_latents[modality_idx] + keyframe_idxs = self.keyframe_idxs[modality_idx] + suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap( + guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list) + # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0. + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + kf_local_pos = [p + 1 for p in kf_local_pos] + window.guide_frames_indices = suffix_idx + window.guide_overlap_info = overlap_info + window.guide_kf_local_positions = kf_local_pos + + # Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims. 
+ # guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims. + guide_downscale_factors = [] + if guide_frame_count > 0: + full_H = guide_frames.shape[3] + for entry_idx, _ in overlap_info: + entry_H = guide_entries[entry_idx]["latent_shape"][1] + guide_downscale_factors.append(full_H // entry_H) + window.guide_downscale_factors = guide_downscale_factors + + if guide_frame_count > 0: + idx = tuple([slice(None)] * self.dim + [suffix_idx]) + return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count + return latent_slice, 0 + + def patch_latent_shapes(self, sub_conds, new_shapes): + if not self.is_multimodal: + return + + for cond_list in sub_conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes) + + @dataclass class ContextSchedule: name: str @@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co class IndexListContextHandler(ContextHandlerABC): def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, - causal_window_fix: bool=True): + latent_retain_index_list: list[int]=[], causal_window_fix: bool=True): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC): self.freenoise = freenoise self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.split_conds_to_windows = split_conds_to_windows + self.latent_retain_index_list = [int(x.strip()) for x in 
latent_retain_index_list.split(",")] if latent_retain_index_list else [] self.causal_window_fix = causal_window_fix self.callbacks = {} + @staticmethod + def _get_latent_shapes(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + if 'latent_shapes' in model_conds: + return model_conds['latent_shapes'].cond + return None + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + @staticmethod + def _get_keyframe_idxs(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + kf = model_conds.get('keyframe_idxs') + if kf is not None and hasattr(kf, 'cond') and kf.cond is not None: + return kf.cond + return None + + def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor: + """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio. + If guide frames are present on the primary modality, only the video portion is shuffled. 
+ """ + guide_entries = self._get_guide_entries(conds) + guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0 + + latent_shapes = self._get_latent_shapes(conds) + if latent_shapes is not None and len(latent_shapes) > 1: + modalities = comfy.utils.unpack_latents(noise, latent_shapes) + primary_total = latent_shapes[0][self.dim] + primary_video_count = modalities[0].size(self.dim) - guide_count + apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed) + for i in range(1, len(modalities)): + mod_total = latent_shapes[i][self.dim] + ratio = mod_total / primary_total if primary_total > 0 else 1 + mod_ctx_len = max(round(self.context_length * ratio), 1) + mod_ctx_overlap = max(round(self.context_overlap * ratio), 0) + modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed) + noise, _ = comfy.utils.pack_latents(modalities) + return noise + video_count = noise.size(self.dim) - guide_count + apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed) + return noise + + def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState: + """Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds.""" + latent_shapes = self._get_latent_shapes(conds) + is_multimodal = latent_shapes is not None and len(latent_shapes) > 1 + unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in] + + unpacked_latents_list = list(unpacked_latents) + guide_latents_list = [None] * len(unpacked_latents) + guide_entries_list = [None] * len(unpacked_latents) + keyframe_idxs_list = [None] * len(unpacked_latents) + + extracted_guide_entries = self._get_guide_entries(conds) + extracted_keyframe_idxs = self._get_keyframe_idxs(conds) + + # Strip guide frames (only from 
first modality for now) + if extracted_guide_entries is not None: + guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries) + if guide_count > 0: + x = unpacked_latents[0] + latent_count = x.size(self.dim) - guide_count + unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count) + guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count) + guide_entries_list[0] = extracted_guide_entries + keyframe_idxs_list[0] = extracted_keyframe_idxs + + + return WindowingState( + latents=unpacked_latents_list, + guide_latents=guide_latents_list, + guide_entries=guide_entries_list, + keyframe_idxs=keyframe_idxs_list, + latent_shapes=latent_shapes, + dim=self.dim, + is_multimodal=is_multimodal, + temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio) + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: - # for now, assume first dim is batch - should have stored on BaseModel in actual implementation - if x_in.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute + total_frame_count = window_state.latents[0].size(self.dim) + if total_frame_count > self.context_length: + logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.") if self.cond_retain_index_list: logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") + if self.latent_retain_index_list: + logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}") return True + logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames 
({total_frame_count}).") return False def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: @@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) + sample_sigmas = model_options["transformer_options"]["sample_sigmas"] + current_timestep = timestep[0].to(sample_sigmas.dtype) + mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: return # substep from multi-step sampler: keep self._step from the last full step @@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): self._model = model self.set_step(timestep, model_options) - context_windows = self.get_context_windows(model, x_in, model_options) - enumerated_context_windows = list(enumerate(context_windows)) - conds_final = [torch.zeros_like(x_in) for _ in conds] + window_state = self._build_window_state(x_in, conds, model) + num_modalities = len(window_state.latents) + + context_windows = self.get_context_windows(model, window_state.latents[0], model_options) + 
enumerated_context_windows = list(enumerate(context_windows)) + total_windows = len(enumerated_context_windows) + + # Initialize per-modality accumulators (length 1 for single-modality) + accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents] if self.fuse_method.name == ContextFuseMethods.RELATIVE: - counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] else: - counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] - biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents] + biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents] for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) + # accumulate results from each context window for enum_window in enumerated_context_windows: - results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + results = self.evaluate_context_windows( + calc_cond_batch, model, x_in, conds, timestep, [enum_window], + model_options, window_state=window_state, total_windows=total_windows) for result in results: - self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, - conds_final, counts_final, biases_final) + # result.sub_conds_out is per-cond, per-modality: list[list[Tensor]] + for mod_idx in range(num_modalities): + mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))] + modality_window = result.window.get_window_for_modality(mod_idx) + 
self.combine_context_window_results( + window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window, + result.window_idx, total_windows, timestep, + accum[mod_idx], counts[mod_idx], biases[mod_idx]) + + # fuse accumulated results into final conds try: - # finalize conds - if self.fuse_method.name == ContextFuseMethods.RELATIVE: - # relative is already normalized, so return as is - del counts_final - return conds_final - else: - # normalize conds via division by context usage counts - for i in range(len(conds_final)): - conds_final[i] /= counts_final[i] - del counts_final - return conds_final + result_out = [] + for ci in range(len(conds)): + finalized = [] + for mod_idx in range(num_modalities): + if self.fuse_method.name != ContextFuseMethods.RELATIVE: + accum[mod_idx][ci] /= counts[mod_idx][ci] + f = accum[mod_idx][ci] + + # if guide frames were injected, append them to the end of the fused latents for the next step + if window_state.guide_latents[mod_idx] is not None: + f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim) + finalized.append(f) + + # pack modalities together if needed + if window_state.is_multimodal and len(finalized) > 1: + packed, _ = comfy.utils.pack_latents(finalized) + else: + packed = finalized[0] + + result_out.append(packed) + return result_out finally: for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): callback(self, model, x_in, conds, timestep, model_options) - def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], - model_options, device=None, first_device=None): + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, + timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, window_state: 
WindowingState, total_windows: int = None, + device=None, first_device=None): + """Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out + + For each window: + 1. Builds windows (for each modality if multimodal) + 2. Slices window for each modality + 3. Injects concatenated latent guide frames where present + 4. Packs together if needed and calls model + 5. Unpacks and strips any guides from outputs + """ + x = window_state.latents[0] + results: list[ContextResults] = [] for window_idx, window in enumerated_context_windows: # allow processing to end between context window executions for faster Cancel comfy.model_management.throw_exception_if_processing_interrupted() - # causal_window_fix: prepend a pre-window frame that will be stripped post-forward + # prepare the window accounting for multimodal windows + window = window_state.prepare_window(window, model) + + # causal_window_fix: prepend a pre-window frame that will be stripped post-forward. + # Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up. 
anchor_applied = False if self.causal_window_fix: anchor_idx = window.index_list[0] - 1 @@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC): window.causal_anchor_index = anchor_idx anchor_applied = True + # slice the window for each modality, injecting guide frames where applicable + sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device) + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) - # update exposed params + logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}" + + (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "") + ) + + # if multimodal, pack modalities together + if window_state.is_multimodal and len(sliced) > 1: + sub_x, sub_shapes = comfy.utils.pack_latents(sliced) + else: + sub_x, sub_shapes = sliced[0], [sliced[0].shape] + + # get resized conds for window model_options["transformer_options"]["context_window"] = window - # get subsections of x, timestep, conds - sub_x = window.get_tensor(x_in, device) - sub_timestep = window.get_tensor(timestep, device, dim=0) - sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + sub_timestep = window.get_tensor(timestep, dim=0) + sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds] + # if multimodal, patch latent_shapes in conds for correct unpacking in model + window_state.patch_latent_shapes(sub_conds, sub_shapes) + + # call model on window sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) - if device is not None: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = 
sub_conds_out[i].to(x_in.device) - # strip causal_window_fix anchor if applied + # unpack outputs + out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))] + + # strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct if anchor_applied: - for i in range(len(sub_conds_out)): - sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1) + for ci in range(len(out_per_modality)): + t = out_per_modality[ci][0] + out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1) - results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + # strip injected guide frames + window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window) + + results.append(ContextResults(window_idx, out_per_modality, sub_conds, window)) return results @@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC): biases_final[i][idx] = bias_total + bias else: # add conds and counts based on weights of fuse method - weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap) weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) for i in range(len(sub_conds_out)): window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) @@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC): callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) -def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): - # limit noise_shape length to context_length for more accurate vram use estimation +def _prepare_sampling_wrapper(executor, model, 
noise_shape: torch.Tensor, conds, *args, **kwargs): + # Scale noise_shape to a single context window so VRAM estimation budgets per-window. model_options = kwargs.get("model_options", None) if model_options is None: raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") handler: IndexListContextHandler = model_options.get("context_handler", None) if handler is not None: noise_shape = list(noise_shape) - noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) - return executor(model, noise_shape, *args, **kwargs) + is_packed = len(noise_shape) == 3 and noise_shape[1] == 1 + if is_packed: + # TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a + # per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM. + pass + elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length: + noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) + return executor(model, noise_shape, conds, *args, **kwargs) def create_prepare_sampling_wrapper(model: ModelPatcher): @@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") if not handler.freenoise: return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + conds = [guider.conds.get('positive', guider.conds.get('negative', []))] + noise = handler._apply_freenoise(noise, conds, extra_args["seed"]) return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) - def create_sampler_sample_wrapper(model: ModelPatcher): model.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, @@ -434,7 +786,6 @@ def 
create_sampler_sample_wrapper(model: ModelPatcher): _sampler_sample_wrapper ) - def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: return ContextSchedule(context_schedule, func) -def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): - return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None): + context_overlap = handler.context_overlap if context_overlap is None else context_overlap + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap) def create_weights_flat(length: int, **kwargs) -> list[float]: @@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]: weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) return weight_sequence -def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs): # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 # only expected overlap is given different weights weights_torch = torch.ones((length)) # blend left-side on all except first window if min(idxs) > 0: - ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) - weights_torch[:handler.context_overlap] = ramp_up + ramp_up 
= torch.linspace(1e-37, 1, context_overlap) + weights_torch[:context_overlap] = ramp_up # blend right-side on all except last window if max(idxs) < full_length-1: - ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) - weights_torch[-handler.context_overlap:] = ramp_down + ramp_down = torch.linspace(1, 1e-37, context_overlap) + weights_torch[-context_overlap:] = ramp_down return weights_torch class ContextFuseMethods: diff --git a/comfy/ldm/boogu/model.py b/comfy/ldm/boogu/model.py new file mode 100644 index 000000000..966f3c583 --- /dev/null +++ b/comfy/ldm/boogu/model.py @@ -0,0 +1,321 @@ +# Boogu-Image-0.1 transformer +# Architecture is an OmniGen2 derivative (see comfy/ldm/omnigen/omnigen2.py) with an +# added dual-stream ("double_stream") stage before the single-stream layers, conditioned +# by a Qwen3-VL multimodal LLM. Reuses the OmniGen2/Lumina building blocks and the Flux +# RoPE core; the only new components are the double-stream block and the hybrid forward order. + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +import comfy.ldm.common_dit +import comfy.ldm.omnigen.omnigen2 +from comfy.ldm.modules.attention import optimized_attention_masked +from comfy.ldm.omnigen.omnigen2 import ( + OmniGen2RotaryPosEmbed, + Lumina2CombinedTimestepCaptionEmbedding, + LuminaRMSNormZero, + LuminaLayerNormContinuous, + LuminaFeedForward, + Attention, + OmniGen2TransformerBlock, + apply_rotary_emb, +) + +class BooguDoubleStreamProcessor(nn.Module): + # Joint attention over [instruct ; img] with separate per-stream q/k/v and output projections. 
+ def __init__(self, dim, head_dim, heads, kv_heads, dtype=None, device=None, operations=None): + super().__init__() + query_dim = head_dim * heads + kv_dim = head_dim * kv_heads + + self.img_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + self.img_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + self.img_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + + self.instruct_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + self.instruct_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + self.instruct_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device) + + self.instruct_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + self.img_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device) + + def forward(self, attn, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}): + batch_size = img_hidden_states.shape[0] + L_instruct = instruct_hidden_states.shape[1] + + img_q = self.img_to_q(img_hidden_states) + img_k = self.img_to_k(img_hidden_states) + img_v = self.img_to_v(img_hidden_states) + + instruct_q = self.instruct_to_q(instruct_hidden_states) + instruct_k = self.instruct_to_k(instruct_hidden_states) + instruct_v = self.instruct_to_v(instruct_hidden_states) + + # Concatenate instruction first, then image (matches reference processor order). 
+ query = torch.cat([instruct_q, img_q], dim=1) + key = torch.cat([instruct_k, img_k], dim=1) + value = torch.cat([instruct_v, img_v], dim=1) + + query = query.view(batch_size, -1, attn.heads, attn.dim_head) + key = key.view(batch_size, -1, attn.kv_heads, attn.dim_head) + value = value.view(batch_size, -1, attn.kv_heads, attn.dim_head) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if rotary_emb is not None: + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if attn.kv_heads < attn.heads: + key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1) + value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1) + + hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options) + + # Split back to instruction/image, apply per-stream output projections, recombine. 
+ instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct]) + img_hidden_states = self.img_out(hidden_states[:, L_instruct:]) + hidden_states = torch.cat([instruct_hidden_states, img_hidden_states], dim=1) + + hidden_states = attn.to_out[0](hidden_states) + return hidden_states + + +class BooguJointAttention(nn.Module): + # Holds the shared q/k RMSNorm + final output projection + def __init__(self, dim, head_dim, heads, kv_heads, eps=1e-5, dtype=None, device=None, operations=None): + super().__init__() + self.heads = heads + self.kv_heads = kv_heads + self.dim_head = head_dim + self.scale = head_dim ** -0.5 + + self.norm_q = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device) + self.to_out = nn.Sequential( + operations.Linear(heads * head_dim, dim, bias=False, dtype=dtype, device=device), + nn.Dropout(0.0), + ) + self.processor = BooguDoubleStreamProcessor(dim, head_dim, heads, kv_heads, dtype=dtype, device=device, operations=operations) + + def forward(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}): + return self.processor(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask, transformer_options=transformer_options) + + +class BooguDoubleStreamBlock(nn.Module): + # Dual-stream block: joint attention over [instruct ; img] + image self-attention, each stream with its own modulation/MLP. 
+ def __init__(self, dim, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=None, device=None, operations=None): + super().__init__() + head_dim = dim // num_attention_heads + + self.img_instruct_attn = BooguJointAttention(dim, head_dim, num_attention_heads, num_kv_heads, eps=1e-5, dtype=dtype, device=device, operations=operations) + self.img_self_attn = Attention( + query_dim=dim, dim_head=head_dim, heads=num_attention_heads, kv_heads=num_kv_heads, + eps=1e-5, bias=False, dtype=dtype, device=device, operations=operations, + ) + + self.img_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations) + self.instruct_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations) + + self.img_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.img_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.img_norm3 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.instruct_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.instruct_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + + self.img_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.img_self_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.img_ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.img_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + + self.instruct_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.instruct_ffn_norm1 = 
operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + self.instruct_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device) + + def forward(self, img_hidden_states, instruct_hidden_states, joint_rotary_emb, img_rotary_emb, temb, joint_attention_mask=None, img_attention_mask=None, transformer_options={}): + L_instruct = instruct_hidden_states.shape[1] + + img_norm1_out, img_gate_msa, img_scale_mlp, img_gate_mlp = self.img_norm1(img_hidden_states, temb) + img_norm2_out, img_shift_mlp, _, _ = self.img_norm2(img_hidden_states, temb) + img_norm3_out, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb) + + instruct_norm1_out, instruct_gate_msa, instruct_scale_mlp, instruct_gate_mlp = self.instruct_norm1(instruct_hidden_states, temb) + instruct_norm2_out, instruct_shift_mlp, _, _ = self.instruct_norm2(instruct_hidden_states, temb) + + joint_attn_out = self.img_instruct_attn(img_norm1_out, instruct_norm1_out, joint_rotary_emb, joint_attention_mask, transformer_options=transformer_options) + instruct_attn_out = joint_attn_out[:, :L_instruct] + img_attn_out = joint_attn_out[:, L_instruct:] + + img_self_attn_out = self.img_self_attn(img_norm3_out, img_norm3_out, img_attention_mask, img_rotary_emb, transformer_options=transformer_options) + + img_hidden_states = img_hidden_states + img_gate_msa.unsqueeze(1).tanh() * self.img_attn_norm(img_attn_out) + img_hidden_states = img_hidden_states + img_gate_self.unsqueeze(1).tanh() * self.img_self_attn_norm(img_self_attn_out) + img_mlp_input = (1 + img_scale_mlp.unsqueeze(1)) * img_norm2_out + img_shift_mlp.unsqueeze(1) + img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_mlp_input)) + img_hidden_states = img_hidden_states + img_gate_mlp.unsqueeze(1).tanh() * self.img_ffn_norm2(img_mlp_out) + + instruct_hidden_states = instruct_hidden_states + instruct_gate_msa.unsqueeze(1).tanh() * self.instruct_attn_norm(instruct_attn_out) + instruct_mlp_input = (1 + 
instruct_scale_mlp.unsqueeze(1)) * instruct_norm2_out + instruct_shift_mlp.unsqueeze(1) + instruct_mlp_out = self.instruct_feed_forward(self.instruct_ffn_norm1(instruct_mlp_input)) + instruct_hidden_states = instruct_hidden_states + instruct_gate_mlp.unsqueeze(1).tanh() * self.instruct_ffn_norm2(instruct_mlp_out) + + return img_hidden_states, instruct_hidden_states + + +class BooguTransformer2DModel(nn.Module): + def __init__( + self, + patch_size: int = 2, + in_channels: int = 16, + out_channels: Optional[int] = None, + hidden_size: int = 3360, + num_layers: int = 32, + num_double_stream_layers: int = 8, + num_refiner_layers: int = 2, + num_attention_heads: int = 28, + num_kv_heads: int = 7, + multiple_of: int = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: float = 1e-5, + axes_dim_rope: Tuple[int, int, int] = (40, 40, 40), + axes_lens: Tuple[int, int, int] = (2048, 1664, 1664), + instruction_feat_dim: int = 4096, + timestep_scale: float = 1000.0, + image_model=None, + device=None, dtype=None, operations=None, + ): + super().__init__() + + self.patch_size = patch_size + self.out_channels = out_channels or in_channels + self.hidden_size = hidden_size + self.dtype = dtype + + self.rope_embedder = OmniGen2RotaryPosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + patch_size=patch_size, + ) + + self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device) + self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + text_feat_dim=instruction_feat_dim, + norm_eps=norm_eps, + timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations + ) + + self.noise_refiner = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, 
norm_eps, modulation=True, dtype=dtype, device=device, operations=operations) + for _ in range(num_refiner_layers) + ]) + + self.ref_image_refiner = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations) + for _ in range(num_refiner_layers) + ]) + + self.context_refiner = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations) + for _ in range(num_refiner_layers) + ]) + + self.double_stream_layers = nn.ModuleList([ + BooguDoubleStreamBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=dtype, device=device, operations=operations) + for _ in range(num_double_stream_layers) + ]) + + self.single_stream_layers = nn.ModuleList([ + OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations) + for _ in range(num_layers) + ]) + + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations + ) + + self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype)) + + # Patchify/refine helpers are identical to OmniGen2; reuse via bound methods. 
+ flat_and_pad_to_seq = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.flat_and_pad_to_seq + img_patch_embed_and_refine = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.img_patch_embed_and_refine + + def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs): + B, C, H, W = x.shape + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + _, _, H_padded, W_padded = hidden_states.shape + timestep = 1.0 - timesteps + text_hidden_states = context + text_attention_mask = attention_mask + ref_image_hidden_states = ref_latents + device = hidden_states.device + + temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) + + ( + hidden_states, ref_image_hidden_states, + img_mask, ref_img_mask, + l_effective_ref_img_len, l_effective_img_len, + ref_img_sizes, img_sizes, + ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) + + ( + context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb, + rotary_emb, encoder_seq_lengths, seq_lengths, + ) = self.rope_embedder( + hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0], + l_effective_ref_img_len, l_effective_img_len, + ref_img_sizes, img_sizes, device, + ) + + for layer in self.context_refiner: + text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options) + + img_len = hidden_states.shape[1] + combined_img_hidden_states = self.img_patch_embed_and_refine( + hidden_states, ref_image_hidden_states, + img_mask, ref_img_mask, + noise_rotary_emb, ref_img_rotary_emb, + l_effective_ref_img_len, l_effective_img_len, + temb, + transformer_options=transformer_options, + ) + + # Double-stream stage: the image self-attention only sees the [ref ; noise] tokens, + # which sit after the instruction tokens in the joint rope. 
+ L_instruct = text_hidden_states.shape[1] + combined_img_rotary_emb = rotary_emb[:, L_instruct:] + for layer in self.double_stream_layers: + combined_img_hidden_states, text_hidden_states = layer( + combined_img_hidden_states, text_hidden_states, + rotary_emb, combined_img_rotary_emb, temb, + joint_attention_mask=None, img_attention_mask=None, + transformer_options=transformer_options, + ) + + hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1) + + for layer in self.single_stream_layers: + hidden_states = layer(hidden_states, None, rotary_emb, temb, transformer_options=transformer_options) + + hidden_states = self.norm_out(hidden_states, temb) + + p = self.patch_size + output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded // p, p1=p, p2=p)[:, :, :H, :W] + + return -output diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 671fe834d..aec874815 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -515,7 +515,7 @@ class Block(nn.Module): h=H, w=W, ) - x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) def _x_fn( _x_B_T_H_W_D: torch.Tensor, @@ -548,7 +548,7 @@ class Block(nn.Module): shift_cross_attn_B_T_1_1_D, transformer_options=transformer_options, ) - x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) normalized_x_B_T_H_W_D = _fn( x_B_T_H_W_D, @@ -557,7 +557,7 @@ class Block(nn.Module): shift_mlp_B_T_1_1_D, ) result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) - x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * 
result_B_T_H_W_D.to(residual_dtype) + x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype)) return x_B_T_H_W_D diff --git a/comfy/ldm/krea2/model.py b/comfy/ldm/krea2/model.py new file mode 100644 index 000000000..ecb16254f --- /dev/null +++ b/comfy/ldm/krea2/model.py @@ -0,0 +1,290 @@ +"""Krea 2 (K2) — single-stream MMDiT. + +Text tokens produced by a Qwen3-VL-4B 12-layer ``txtfusion`` adapter and patchified image tokens are +concatenated into one sequence and run through ``layers`` shared transformer blocks with +AdaLN-single modulation, GQA + per-head QK-norm + sigmoid-gated attention, SwiGLU MLP, and 3-axis RoPE. +""" + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +import comfy.model_management +import comfy.patcher_extension +import comfy.ldm.common_dit +from comfy.ldm.flux.layers import EmbedND, timestep_embedding +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention_masked + + +class RMSNorm(nn.Module): + """RMSNorm with the reference ``(1 + scale)`` weight convention (scale stored zero-centered).""" + + def __init__(self, features: int, eps: float = 1e-5, device=None, dtype=None, operations=None): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.empty(features, device=device, dtype=dtype)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + weight = comfy.model_management.cast_to(self.scale, dtype=torch.float32, device=x.device) + 1.0 + return F.rms_norm(x.float(), (x.shape[-1],), weight=weight, eps=self.eps).to(dtype) + + +class QKNorm(nn.Module): + def __init__(self, dim: int, device=None, dtype=None, operations=None): + super().__init__() + self.qnorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations) + self.knorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations) + + def 
forward(self, q, k): + return self.qnorm(q), self.knorm(k) + + +class SwiGLU(nn.Module): + def __init__(self, features: int, multiplier: int, bias: bool = False, multiple: int = 128, + device=None, dtype=None, operations=None): + super().__init__() + mlpdim = int(2 * features / 3) * multiplier + mlpdim = multiple * ((mlpdim + multiple - 1) // multiple) + self.gate = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype) + self.up = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype) + self.down = operations.Linear(mlpdim, features, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + return self.down(F.silu(self.gate(x)).mul_(self.up(x))) + + +class Attention(nn.Module): + def __init__(self, dim: int, heads: int, kvheads: Optional[int] = None, bias: bool = False, + device=None, dtype=None, operations=None): + super().__init__() + self.heads = heads + self.kvheads = kvheads if kvheads is not None else heads + self.headdim = dim // self.heads + self.wq = operations.Linear(dim, self.headdim * self.heads, bias=bias, device=device, dtype=dtype) + self.wk = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype) + self.wv = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype) + self.gate = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype) + self.qknorm = QKNorm(self.headdim, device=device, dtype=dtype, operations=operations) + self.wo = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype) + + def forward(self, x, freqs=None, mask=None, transformer_options={}): + q, k, v, gate = self.wq(x), self.wk(x), self.wv(x), self.gate(x) + q = rearrange(q, "B L (H D) -> B H L D", H=self.heads) + k = rearrange(k, "B L (H D) -> B H L D", H=self.kvheads) + v = rearrange(v, "B L (H D) -> B H L D", H=self.kvheads) + q, k = self.qknorm(q, k) + if freqs is not None: + q, k = apply_rope(q, k, freqs) + if self.kvheads != self.heads: 
+ rep = self.heads // self.kvheads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + out = optimized_attention_masked(q, k, v, self.heads, mask=mask, skip_reshape=True, + transformer_options=transformer_options) + return self.wo(out * F.sigmoid(gate)) + + +class SimpleModulation(nn.Module): + def __init__(self, dim: int, device=None, dtype=None, operations=None): + super().__init__() + self.lin = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) + + def forward(self, vec): + out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device).unsqueeze(0) + scale, shift = out.chunk(2, dim=1) + return scale, shift + + +class DoubleSharedModulation(nn.Module): + def __init__(self, dim: int, device=None, dtype=None, operations=None): + super().__init__() + self.lin = nn.Parameter(torch.empty(6 * dim, device=device, dtype=dtype)) + + def forward(self, vec): + out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device) + return out.chunk(6, dim=-1) + + +class TextFusionBlock(nn.Module): + def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None): + super().__init__() + self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations) + self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations) + + def forward(self, x, mask=None, transformer_options={}): + x = x + self.attn(self.prenorm(x), mask=mask, transformer_options=transformer_options) + x = x + self.mlp(self.postnorm(x)) + return x + + +class TextFusionTransformer(nn.Module): + def __init__(self, num_txt_layers, txt_dim, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None): + 
super().__init__() + self.layerwise_blocks = nn.ModuleList([ + TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations) + for _ in range(2) + ]) + self.projector = operations.Linear(num_txt_layers, 1, bias=False, device=device, dtype=dtype) + self.refiner_blocks = nn.ModuleList([ + TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations) + for _ in range(2) + ]) + + def forward(self, x, mask=None, transformer_options={}): + b, l, n, d = x.shape + x = x.reshape(b * l, n, d) + for block in self.layerwise_blocks: + x = block(x.contiguous(), mask=None, transformer_options=transformer_options) + x = rearrange(x, "(b l) n d -> b l d n", b=b, l=l) + x = self.projector(x).squeeze(-1) + for block in self.refiner_blocks: + x = block(x, mask=mask, transformer_options=transformer_options) + return x + + +class SingleStreamBlock(nn.Module): + def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None): + super().__init__() + self.mod = DoubleSharedModulation(features, device=device, dtype=dtype, operations=operations) + self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations) + self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations) + + def forward(self, x, vec, freqs, mask=None, transformer_options={}): + prescale, preshift, pregate, postscale, postshift, postgate = self.mod(vec) + x = x + pregate * self.attn((1 + prescale) * self.prenorm(x) + preshift, freqs, mask, transformer_options=transformer_options) + x = x + postgate * self.mlp((1 + postscale) * self.postnorm(x) + postshift) + return x + + +class LastLayer(nn.Module): + def __init__(self, 
features, patch, channels, device=None, dtype=None, operations=None): + super().__init__() + self.norm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.linear = operations.Linear(features, patch * patch * channels, bias=True, device=device, dtype=dtype) + self.modulation = SimpleModulation(features, device=device, dtype=dtype, operations=operations) + + def forward(self, x, tvec): + scale, shift = self.modulation(tvec) + x = (1 + scale) * self.norm(x) + shift + return self.linear(x) + + +class SingleStreamDiT(nn.Module): + def __init__(self, features=6144, tdim=256, txtdim=2560, heads=48, kvheads=12, multiplier=4, + layers=28, patch=2, channels=16, bias=False, theta=1e3, txtlayers=12, + txtheads=20, txtkvheads=20, image_model=None, + device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + self.patch = patch + self.channels = channels + self.tdim = tdim + self.heads = heads + self.txtdim = txtdim + self.txtlayers = txtlayers + + headdim = features // heads + axes = [headdim - 12 * (headdim // 16), 6 * (headdim // 16), 6 * (headdim // 16)] + assert sum(axes) == headdim, f"axes {axes} sum != headdim {headdim}" + self.pe_embedder = EmbedND(dim=headdim, theta=int(theta), axes_dim=axes) + + self.first = operations.Linear(channels * patch ** 2, features, bias=True, device=device, dtype=dtype) + self.blocks = nn.ModuleList([ + SingleStreamBlock(features, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations) + for _ in range(layers) + ]) + self.tmlp = nn.Sequential( + operations.Linear(tdim, features, device=device, dtype=dtype), + nn.GELU(approximate="tanh"), + operations.Linear(features, features, device=device, dtype=dtype), + ) + self.txtfusion = TextFusionTransformer(txtlayers, txtdim, txtheads, multiplier, bias, txtkvheads, + device=device, dtype=dtype, operations=operations) + self.txtmlp = nn.Sequential( + RMSNorm(txtdim, device=device, dtype=dtype, 
operations=operations), + operations.Linear(txtdim, features, device=device, dtype=dtype), + nn.GELU(approximate="tanh"), + operations.Linear(features, features, device=device, dtype=dtype), + ) + self.last = LastLayer(features, patch, channels, device=device, dtype=dtype, operations=operations) + self.tproj = nn.Sequential( + nn.GELU(approximate="tanh"), + operations.Linear(features, features * 6, device=device, dtype=dtype), + ) + + def forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + + def _forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs): + temporal = x.ndim == 5 + if temporal: + b5, c5, t5, h5, w5 = x.shape + x = x.reshape(b5 * t5, c5, h5, w5) + bs, c, H_orig, W_orig = x.shape + patch = self.patch + # Pad the latent up to a multiple of patch (as Flux/Lumina/QwenImage do); crop back at the end. + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch, patch)) + H, W = x.shape[-2], x.shape[-1] + h_, w_ = H // patch, W // patch + + # context arrives as (B, seq, txtlayers*txtdim); reshape to (B, seq, txtlayers, txtdim). + context = self._unpack_context(context) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch, pw=patch) + img = self.first(img) + + t = self.tmlp(timestep_embedding(timesteps, self.tdim).unsqueeze(1).to(img.dtype)) + tvec = self.tproj(t) + + context = self.txtfusion(context, mask=None, transformer_options=transformer_options) + context = self.txtmlp(context) + + txtlen, imglen = context.shape[1], img.shape[1] + combined = torch.cat((context, img), dim=1) + + # Position ids: text at 0, image at (0, h_idx, w_idx).
+ device = combined.device + txtpos = torch.zeros(bs, txtlen, 3, device=device, dtype=torch.float32) + imgids = torch.zeros(h_, w_, 3, device=device, dtype=torch.float32) + imgids[..., 1] = torch.arange(h_, device=device, dtype=torch.float32)[:, None] + imgids[..., 2] = torch.arange(w_, device=device, dtype=torch.float32)[None, :] + imgpos = imgids.reshape(1, h_ * w_, 3).repeat(bs, 1, 1) + pos = torch.cat((txtpos, imgpos), dim=1) + + freqs = self.pe_embedder(pos) + + for block in self.blocks: + combined = block(combined, tvec, freqs, None, transformer_options=transformer_options) + + final = self.last(combined, t) + out = final[:, txtlen:txtlen + imglen, :] + out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=h_, w=w_, ph=patch, pw=patch, c=self.channels) + out = out[:, :, :H_orig, :W_orig] # crop padding back off + if temporal: + out = out.reshape(b5, t5, self.channels, H_orig, W_orig).movedim(1, 2) + return out + + def _unpack_context(self, context): + # context: (B, seq, txtlayers*txtdim) -> (B, seq, txtlayers, txtdim). + b, seq, fused = context.shape + if fused != self.txtlayers * self.txtdim: + raise ValueError( + f"Krea2 expects conditioning with {self.txtlayers}x{self.txtdim}={self.txtlayers * self.txtdim} " + f"features (a {self.txtlayers}-layer Qwen3-VL stack) but got {fused}. " + f"Load the text encoder with CLIPLoader type 'krea2'." 
+ ) + return context.reshape(b, seq, self.txtlayers, self.txtdim) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index e0a4a0f9b..9953b6679 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel): ) grid_mask = None - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: additional_args.update({ "orig_patchified_shape": list(x.shape)}) denoise_mask = self.patchifier.patchify(denoise_mask)[0] grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] @@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel): x = x * (1 + scale) + shift x = self.proj_out(x) - if keyframe_idxs is not None: + if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0: grid_mask = kwargs["grid_mask"] orig_patchified_shape = kwargs["orig_patchified_shape"] full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index e9ca5229d..b8da4cf39 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -22,7 +22,7 @@ def apply_rotary_emb(x, freqs_cis): def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return F.silu(x) * y + return F.silu(x, inplace=True).mul_(y) class TimestepEmbedding(nn.Module): diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 282408891..1c9782a38 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1665,7 +1665,7 @@ class SCAILWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) - if ref_mask_latents is not None: # SCAIL-2 additive mask stream + if ref_mask_latents is not None: # SCAIL-2 additive mask stream (one identity mask frame per reference, then video) x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) grid_sizes = x.shape[2:] transformer_options["grid_sizes"] = grid_sizes @@ -1728,22 +1728,25 @@ class SCAILWanModel(WanModel): # 
ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode, # which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset. + # reference_latent may stack several frames: the last is the primary reference adjacent to the video, the earlier frames are additional references. def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + if ref_mask_flag is not None and not bool(ref_mask_flag): REF_ROPE_H = 120.0 POSE_ROPE_W = 120.0 - ref_t_patches = 0 - if reference_latent is not None: - ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] main_t_patches = t - ref_t_patches + video_t_start = max(ref_t_patches - 1, 0) parts = [] if ref_t_patches > 0: ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) if main_t_patches > 0: - parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) + parts.append(super().rope_encode(main_t_patches, h, w, t_start=video_t_start, device=device, dtype=dtype, transformer_options=transformer_options)) if pose_latents is not None: F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] @@ -1752,7 +1755,7 @@ class SCAILWanModel(WanModel): h_shift = (h_scale - 1) / 2 w_shift = (w_scale - 1) / 2 pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} - parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, 
device=device, dtype=dtype, transformer_options=pose_tf)) + parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=video_t_start, device=device, dtype=dtype, transformer_options=pose_tf)) return torch.cat(parts, dim=1) @@ -1761,10 +1764,6 @@ class SCAILWanModel(WanModel): if pose_latents is None: return main_freqs - ref_t_patches = 0 - if reference_latent is not None: - ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] - F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames diff --git a/comfy/lora.py b/comfy/lora.py index 2c8d0f0bf..427cf98aa 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -326,6 +326,17 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Krea2): + diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["transformer.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + key_map[key_lora] = to + if isinstance(model, comfy.model_base.Lumina2): diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: diff --git a/comfy/model_base.py b/comfy/model_base.py index 8b9f93ca2..f62788c29 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.ldm.lightricks.symmetric_patchifier 
import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC @@ -54,9 +55,11 @@ import comfy.ldm.pixeldit.model import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.boogu.model import comfy.ldm.qwen_image.model import comfy.ldm.joyimage.model import comfy.ldm.ideogram4.model +import comfy.ldm.krea2.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 @@ -1204,6 +1207,127 @@ class LTXAV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image + def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim): + result = [primary_indices] + if len(latent_shapes) < 2: + return result + + video_total = latent_shapes[0][dim] + + for i in range(1, len(latent_shapes)): + mod_total = latent_shapes[i][dim] + # Map each primary index to its proportional range of modality indices and + # concatenate in order. Preserves wrapped/strided geometry so the modality + # attends to the same temporal regions as the primary window. 
+ mod_indices = [] + seen = set() + for v_idx in primary_indices: + a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1) + a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total) + if a_end <= a_start: + a_end = a_start + 1 + for a in range(a_start, a_end): + if a not in seen: + seen.add(a) + mod_indices.append(a) + result.append(mod_indices) + + return result + + @staticmethod + def _get_guide_entries(conds): + for cond_list in conds: + if cond_list is None: + continue + for cond_dict in cond_list: + model_conds = cond_dict.get('model_conds', {}) + entries = model_conds.get('guide_attention_entries') + if entries is not None and hasattr(entries, 'cond') and entries.cond: + return entries.cond + return None + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # Audio denoise mask — slice using audio modality window + if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows: + audio_window = window.modality_windows.get(1) + if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + sliced = audio_window.get_tensor(cond_value.cond, device, dim=2) + return cond_value._copy_with(sliced) + + # Video denoise mask — split into video + guide portions, slice each + if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + cond_tensor = cond_value.cond + guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim) + if guide_count > 0: + T_video = x_in.size(window.dim) + video_mask = cond_tensor.narrow(window.dim, 0, T_video) + guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count) + sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list) + suffix_indices = window.guide_frames_indices + if suffix_indices: + idx = tuple([slice(None)] * window.dim + [suffix_indices]) + sliced_guide 
= guide_mask[idx].to(device) + return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim)) + else: + return cond_value._copy_with(sliced_video) + + # Keyframe indices — regenerate pixel coords for window, select guide positions + if cond_key == "keyframe_idxs": + kf_local_pos = window.guide_kf_local_positions + if not kf_local_pos: + return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty + H, W = x_in.shape[3], x_in.shape[4] + window_len = len(window.index_list) + # account for causal_window_fix anchor in coord space size + anchor_idx = getattr(window, 'causal_anchor_index', None) + if anchor_idx is not None and anchor_idx >= 0: + window_len += 1 + patchifier = self.diffusion_model.patchifier + latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device) + scale_factors = self.diffusion_model.vae_scale_factors + pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords( + latent_coords, + scale_factors, + causal_fix=self.diffusion_model.causal_temporal_positioning) + tokens = [] + for pos in kf_local_pos: + tokens.extend(range(pos * H * W, (pos + 1) * H * W)) + pixel_coords = pixel_coords[:, :, tokens, :] + + # Adjust spatial end positions for dilated (downscaled) guides. + # Each guide entry may have a different downscale factor; expand the + # per-entry factor to cover all tokens belonging to that entry. 
+ downscale_factors = window.guide_downscale_factors + overlap_info = window.guide_overlap_info + if downscale_factors: + per_token_factor = [] + for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors): + per_token_factor.extend([dsf] * (overlap_count * H * W)) + factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype) + spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor( + scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset + + B = cond_value.cond.shape[0] + if B > 1: + pixel_coords = pixel_coords.expand(B, -1, -1, -1) + return cond_value._copy_with(pixel_coords) + + # Guide attention entries — adjust per-guide counts based on window overlap + if cond_key == "guide_attention_entries": + overlap_info = window.guide_overlap_info + H, W = x_in.shape[3], x_in.shape[4] + new_entries = [] + for entry_idx, overlap_count in overlap_info: + e = cond_value.cond[entry_idx] + new_entries.append({**e, + "pre_filter_count": overlap_count * H * W, + "latent_shape": [overlap_count, H, W]}) + return cond_value._copy_with(new_entries) + + return None + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) @@ -1748,10 +1872,14 @@ class WAN21_SCAIL(WAN21): reference_latents = kwargs.get("reference_latents", None) if reference_latents is not None: - ref_latent = self.process_latent_in(reference_latents[-1]) - ref_mask = torch.ones_like(ref_latent[:, :4]) - ref_latent = torch.cat([ref_latent, ref_mask], dim=1) - out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + # SCAIL-2 multi-reference: reference_latents[0] is the primary ref, [1:] are additional + # references. 
Stack as [additional..., primary] so the primary stays adjacent to the video. + ordered = list(reference_latents[1:]) + list(reference_latents[:1]) + stacked = [] + for lat in ordered: + lat = self.process_latent_in(lat) + stacked.append(torch.cat([lat, torch.ones_like(lat[:, :4])], dim=1)) + out['reference_latent'] = comfy.conds.CONDRegular(torch.cat(stacked, dim=2)) pose_latents = kwargs.get("pose_video_latent", None) if pose_latents is not None: @@ -1793,6 +1921,7 @@ class WAN21_SCAIL2(WAN21_SCAIL): if driving_mask_28ch is not None: out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous()) + # ref_mask_28ch holds one identity mask per stacked reference frame (additional refs first, then the primary ref), followed by zeros over the video frames. ref_mask_28ch = kwargs.get("ref_mask_28ch", None) if ref_mask_28ch is not None: out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous()) @@ -1820,10 +1949,11 @@ class WAN21_SCAIL2(WAN21_SCAIL): # Return sliced view omitting retain_index_list return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0) if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - # The ref mask is just a single frame padded with frames of zeros, so just grab the first frames for all windows + # The ref mask is N leading ref frames padded with frames of zeros, so just grab the first frames for all windows full_ref_mask = cond_value.cond video_frame_count = x_in.shape[2] - if full_ref_mask.shape[2] != video_frame_count + 1: + ref_frame_count = full_ref_mask.shape[2] - video_frame_count + if ref_frame_count < 1: return None window_length = len(window.index_list) @@ -1832,7 +1962,7 @@ class WAN21_SCAIL2(WAN21_SCAIL): if anchor_index is not None and anchor_index >= 0: window_length += 1 - window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device) + window_ref_mask = 
full_ref_mask[:, :, :window_length + ref_frame_count].to(device) return cond_value._copy_with(window_ref_mask) return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) @@ -2098,6 +2228,11 @@ class Omnigen2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class Boogu(Omnigen2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(Omnigen2, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.boogu.model.BooguTransformer2DModel) + self.memory_usage_factor_conds = ("ref_latents",) + class QwenImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) @@ -2265,6 +2400,17 @@ class Ideogram4(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out +class Krea2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.krea2.model.SingleStreamDiT) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class HunyuanImage21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca43883a8..cd8df5a87 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -761,6 +761,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if 
'{}double_stream_layers.0.img_instruct_attn.processor.img_to_q.weight'.format(key_prefix) in state_dict_keys: # Boogu-Image (OmniGen2 derivative + dual-stream stage) + dit_config = {} + dit_config["image_model"] = "boogu" + dit_config["hidden_size"] = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[0] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}single_stream_layers.'.format(key_prefix) + '{}.') + dit_config["num_double_stream_layers"] = count_blocks(state_dict_keys, '{}double_stream_layers.'.format(key_prefix) + '{}.') + dit_config["num_refiner_layers"] = count_blocks(state_dict_keys, '{}noise_refiner.'.format(key_prefix) + '{}.') + dit_config["instruction_feat_dim"] = state_dict['{}time_caption_embed.caption_embedder.0.weight'.format(key_prefix)].shape[0] + return dit_config + if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2 dit_config = {} dit_config["image_model"] = "omnigen2" @@ -845,6 +855,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') return dit_config + if '{}txtfusion.projector.weight'.format(key_prefix) in state_dict_keys: # Krea 2 (K2) + dit_config = {} + dit_config["image_model"] = "krea2" + head_dim = 128 + first_w = state_dict['{}first.weight'.format(key_prefix)] # (features, channels*patch^2) + dit_config["features"] = first_w.shape[0] + dit_config["channels"] = first_w.shape[1] // (2 * 2) # patch=2 + dit_config["patch"] = 2 + dit_config["layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + dit_config["heads"] = state_dict['{}blocks.0.attn.wq.weight'.format(key_prefix)].shape[0] // head_dim + dit_config["kvheads"] = state_dict['{}blocks.0.attn.wk.weight'.format(key_prefix)].shape[0] // head_dim + dit_config["txtlayers"] = state_dict['{}txtfusion.projector.weight'.format(key_prefix)].shape[1] + dit_config["txtdim"] = 
state_dict['{}txtfusion.layerwise_blocks.0.prenorm.scale'.format(key_prefix)].shape[0] + return dit_config + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 dit_config = {} model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] diff --git a/comfy/ops.py b/comfy/ops.py index 3f088a962..69d32e254 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -256,7 +256,7 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w if (want_requant and len(fns) == 0 or update_weight): seed = comfy.utils.string_to_seed(s.seed_key) if isinstance(orig, QuantizedTensor): - y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + y = orig.requantize_from_float(x, scale="recalculate", stochastic_rounding=seed) else: y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed) if want_requant and len(fns) == 0: @@ -1089,6 +1089,19 @@ def _load_quantized_module(module, super_load, state_dict, prefix, local_metadat if ts is None or bs is None: raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") scales = {"scale": ts, "block_scale": bs} + elif module.quant_format == "int8_tensorwise": + scale = pop_scale("weight_scale") + if scale is None: + raise ValueError(f"Missing INT8 weight scale for layer {layer_name}") + scales = {"scale": scale} + params_conf = layer_conf.get("params", {}) + if not isinstance(params_conf, dict): + params_conf = {} + if layer_conf.get("convrot", params_conf.get("convrot", False)): + scales["convrot"] = True + scales["convrot_groupsize"] = int( + layer_conf.get("convrot_groupsize", params_conf.get("convrot_groupsize", 256)) + ) else: raise ValueError(f"Unsupported quantization format: {module.quant_format}") @@ -1131,6 +1144,10 @@ def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extr quant_conf = {"format": module.quant_format} if getattr(module, 
'_full_precision_mm_config', False): quant_conf["full_precision_matrix_mult"] = True + params = getattr(module.weight, "_params", None) + if module.quant_format == "int8_tensorwise" and getattr(params, "convrot", False): + quant_conf["convrot"] = True + quant_conf["convrot_groupsize"] = getattr(params, "convrot_groupsize", 256) if extra_quant_conf: quant_conf.update(extra_quant_conf) sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8) @@ -1183,8 +1200,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant) + def forward_comfy_cast_weights( + self, + input, + compute_dtype=None, + want_requant=False, + weight_only_quant=False, + ): + if weight_only_quant: + weight, bias, offload_stream = cast_bias_weight( + self, + input=None, + dtype=self.weight.dtype, + device=input.device, + bias_dtype=input.dtype, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=True, + ) + weight = weight.to(dtype=input.dtype) + else: + weight, bias, offload_stream = cast_bias_weight( + self, + input, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=want_requant, + ) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -1203,9 +1245,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec not getattr(self, 'comfy_force_cast_weights', False) and len(self.weight_function) == 0 and len(self.bias_function) == 0 ) + quantize_input = QUANT_ALGOS.get(getattr(self, 'quant_format', None), {}).get("quantize_input", True) # Training path: quantized forward with compute_dtype backward via 
autograd function - if (input.requires_grad and _use_quantized): + if (input.requires_grad and _use_quantized and quantize_input): weight, bias, offload_stream = cast_bias_weight( self, @@ -1227,7 +1270,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec return output # Inference path (unchanged) - if _use_quantized: + if _use_quantized and quantize_input: # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input @@ -1241,7 +1284,13 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor)) + weight_only_quant = _use_quantized and not quantize_input and isinstance(self.weight, QuantizedTensor) + output = self.forward_comfy_cast_weights( + input, + compute_dtype, + want_requant=isinstance(input, QuantizedTensor), + weight_only_quant=weight_only_quant, + ) # Reshape output back to 3D if input was 3D if reshaped_3d: @@ -1257,8 +1306,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - # dtype is now implicit in the layout class - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) + weight = self.weight.requantize_from_float(weight, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) else: weight = weight.to(self.weight.dtype) if return_weight: diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index b90bcfd25..44f25a97e 100644 --- 
a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -10,6 +10,7 @@ try: QuantizedLayout, TensorCoreFP8Layout as _CKFp8Layout, TensorCoreNVFP4Layout as _CKNvfp4Layout, + TensorWiseINT8Layout as _CKTensorWiseINT8Layout, register_layout_op, register_layout_class, get_layout_class, @@ -47,6 +48,9 @@ except ImportError as e: class _CKNvfp4Layout: pass + class _CKTensorWiseINT8Layout: + pass + def register_layout_class(name, cls): pass @@ -174,6 +178,7 @@ class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): # Backward compatibility alias - default to E4M3 TensorCoreFP8Layout = TensorCoreFP8E4M3Layout +TensorWiseINT8Layout = _CKTensorWiseINT8Layout # ============================================================================== @@ -184,6 +189,7 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) +register_layout_class("TensorWiseINT8Layout", _CKTensorWiseINT8Layout) if _CK_MXFP8_AVAILABLE: register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) @@ -214,6 +220,13 @@ if _CK_MXFP8_AVAILABLE: "group_size": 32, } +QUANT_ALGOS["int8_tensorwise"] = { + "storage_t": torch.int8, + "parameters": {"weight_scale"}, + "comfy_tensor_layout": "TensorWiseINT8Layout", + "quantize_input": False, +} + # ============================================================================== # Re-exports for backward compatibility @@ -226,6 +239,7 @@ __all__ = [ "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreNVFP4Layout", + "TensorWiseINT8Layout", "QUANT_ALGOS", "register_layout_op", ] diff --git a/comfy/sd.py b/comfy/sd.py index 3353eeb9d..943f249c2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import 
comfy.text_encoders.z_image +import comfy.text_encoders.krea2 import comfy.text_encoders.ideogram4 import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 @@ -68,6 +69,7 @@ import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 import comfy.text_encoders.qwen3vl +import comfy.text_encoders.boogu import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo @@ -1302,7 +1304,9 @@ class CLIPType(Enum): LENS = 28 PIXELDIT = 29 IDEOGRAM4 = 30 - JOYIMAGE = 31 + BOOGU = 31 + KREA2 = 32 + JOYIMAGE = 33 @@ -1627,6 +1631,18 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer + elif clip_type == CLIPType.BOOGU and te_model == TEModel.QWEN3VL_8B: # Boogu-Image: full Qwen3-VL-8B, last hidden state, no-think template. + clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer + elif clip_type == CLIPType.KREA2 and te_model == TEModel.QWEN3VL_4B: # Krea2: full Qwen3-VL-4B (12-layer tap for conditioning + multimodal generate). 
+ clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + clip_target.clip = comfy.text_encoders.krea2.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.krea2.Krea2Tokenizer + elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused. + klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b" + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type) + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B if te_model == TEModel.QWEN3VL_8B else comfy.text_encoders.flux.KleinTokenizer else: clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index eb212f84b..2c9770134 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -25,6 +25,8 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.ideogram4 +import comfy.text_encoders.boogu +import comfy.text_encoders.krea2 import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image @@ -1758,6 +1760,27 @@ class Omnigen2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) +class Boogu(Omnigen2): + unet_config = { + "image_model": "boogu", + } + + sampling_settings = { + "multiplier": 1.0, + 
"shift": 3.16, + } + + memory_usage_factor = 2.15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Boogu(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.boogu.BooguTokenizer, comfy.text_encoders.boogu.te(**hunyuan_detect)) + class Ideogram4(supported_models_base.BASE): unet_config = { "image_model": "ideogram4", @@ -1796,6 +1819,35 @@ class Ideogram4(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect)) + +class Krea2(supported_models_base.BASE): + unet_config = { + "image_model": "krea2", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.15, + } + + memory_usage_factor = 2.2 + + latent_format = latent_formats.Wan21 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Krea2(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_4b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.krea2.Krea2Tokenizer, comfy.text_encoders.krea2.te(**hunyuan_detect)) + class QwenImage(supported_models_base.BASE): unet_config = { "image_model": "qwen_image", @@ -2339,9 +2391,11 @@ models = [ ACEStep, ACEStep15, Omnigen2, + Boogu, QwenImage, JoyImage, Ideogram4, + Krea2, Flux2, Lens, 
Kandinsky5Image, diff --git a/comfy/text_encoders/boogu.py b/comfy/text_encoders/boogu.py new file mode 100644 index 000000000..d9de92f10 --- /dev/null +++ b/comfy/text_encoders/boogu.py @@ -0,0 +1,58 @@ +"""Boogu-Image text encoder: full Qwen3-VL-8B, last hidden state (4096-dim). + +Boogu uses the final hidden state of Qwen3-VL as the per-token instruction feature +(num_instruction_feature_layers=1, reduce_type=mean -> just the last layer). +The model itself is the standard Qwen3-VL TE, only the chat template differs +(a fixed system prompt and no <think> block). +""" + +import comfy.text_encoders.qwen3vl +from comfy import sd1_clip + + +# System prompts from the reference pipeline (pipeline_boogu.py). +# T2I (non-empty instruction, no image) uses the helpful-assistant prompt +# everything else (the CFG negative / "drop" condition, and any image case) uses the TI2I "describe" prompt. +BOOGU_T2I_SYSTEM = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows." +BOOGU_DROP_SYSTEM = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
+ + +class BooguTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_8b") + # apply_chat_template without add_generation_prompt + self.llama_template = "<|im_start|>system\n" + BOOGU_T2I_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n" + self.llama_template_images = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n" + # Reference SYSTEM_PROMPT_DROP: used for the empty negative/uncond instruction. + self.llama_template_drop = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs): + if llama_template is None and len(images) == 0 and text.strip() == "": + llama_template = self.llama_template_drop + # Boogu conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds by default.
+ return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs) + + +class BooguQwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel): + def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"): + super().__init__(device=device, dtype=dtype, attention_mask=attention_mask, model_options=model_options, model_type=model_type) + # apply the final RMSNorm to the tapped last layer + self.layer_norm_hidden_state = True + + +class BooguTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + clip_model = lambda **kw: BooguQwen3VLClipModel(**kw, model_type="qwen3vl_8b") + super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=clip_model, model_options=model_options) + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class BooguTEModel_(BooguTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return BooguTEModel_ diff --git a/comfy/text_encoders/krea2.py b/comfy/text_encoders/krea2.py new file mode 100644 index 000000000..408a03566 --- /dev/null +++ b/comfy/text_encoders/krea2.py @@ -0,0 +1,84 @@ +"""Krea 2 (K2) text encoder: Qwen3-VL-4B, 12-layer tap. + +K2 conditions on a stack of hidden states from 12 layers of Qwen3-VL-4B +(reference taps ``hidden_states[2,5,8,...,35]``), kept as a ``(B, 12, seq, 2560)`` tensor and +consumed by the DiT's internal ``txtfusion`` adapter. 
Comfy carries conditioning as a 3D tensor, +so the 12-layer stack is flattened to ``(B, seq, 12*2560)`` here and unpacked inside the model. +""" + +import numbers + +import torch + +import comfy.text_encoders.qwen3vl +from comfy import sd1_clip + +# tap k == hidden_states[k] (no offset). +KREA2_TAP_LAYERS = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35] + +# Identical system template to Qwen-Image; Krea2 strips the system+user-opening prefix. +KREA2_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + +class Krea2Tokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_4b") + self.llama_template = KREA2_TEMPLATE # conditioning template; image text-gen uses qwen3vl's default image template. + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs): + # Krea2 conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds.
+ return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs) + + +class Krea2Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel): + def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=KREA2_TAP_LAYERS, layer_idx=None, dtype=dtype, + attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_4b") + + +class Krea2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3vl_4b", clip_model=Krea2Qwen3VLClipModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) # out: (B, 12, seq, 2560) + tok_pairs = token_weight_pairs["qwen3vl_4b"][0] + + # Strip the system + user-opening prefix + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + if out.shape[2] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: # "user" + if tok_pairs[template_end + 2][0] == 198: # "\n" + template_end += 3 + + out = out[:, :, template_end:] + + b, n, seq, h = out.shape + # Flatten the 12-layer axis into the feature dim: (B, seq, 12*2560). Unpacked in the model. 
+ out = out.permute(0, 2, 1, 3).reshape(b, seq, n * h) + + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class Krea2TEModel_(Krea2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Krea2TEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 09d783fff..61c2a22dd 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""): return key_map +def krea2_to_diffusers(mmdit_config, output_prefix=""): + n_layers = mmdit_config.get("layers", 0) + n_txt_layerwise = 2 # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks + n_txt_refiner = 2 + key_map = {} + + def add_block(prefix_to, prefix_from): + block_map = { + "attn.to_q": "attn.wq", "attn.to_k": "attn.wk", "attn.to_v": "attn.wv", + "attn.to_gate": "attn.gate", "attn.to_out.0": "attn.wo", + "attn.to_out": "attn.wo", # some tools drop the ".0" on to_out + "ff.gate": "mlp.gate", "ff.up": "mlp.up", "ff.down": "mlp.down", + } + for d, c in block_map.items(): + key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c) + + for i in range(n_layers): + add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i)) + for i in range(n_txt_layerwise): + add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i)) + for i in range(n_txt_refiner): + add_block("text_fusion.refiner_blocks.{}".format(i), 
"txtfusion.refiner_blocks.{}".format(i)) + + MAP_BASIC = [ + ("img_in", "first"), + ("time_embed.linear_1", "tmlp.0"), + ("time_embed.linear_2", "tmlp.2"), + ("time_mod_proj", "tproj.1"), + ("txt_in.linear_1", "txtmlp.1"), + ("txt_in.linear_2", "txtmlp.3"), + ("text_fusion.projector", "txtfusion.projector"), + ("final_layer.linear", "last.linear"), + ] + for d, c in MAP_BASIC: + key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c) + + return key_map + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index adb5a3144..0f30608a9 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -25,6 +25,11 @@ CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = { "default": False, "description": "Show the sign-in button in the frontend even when not signed in", }, + "enable_telemetry": { + "type": "bool", + "default": False, + "description": "Signal the frontend that telemetry collection is enabled", + }, } diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 012fae3ac..58e49d8e2 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -891,6 +891,14 @@ class Tracks(ComfyTypeIO): track_visibility: torch.Tensor Type = TrackDict +@comfytype(io_type="DICT") +class Dict(ComfyTypeIO): + Type = dict + +@comfytype(io_type="ARRAY") +class Array(ComfyTypeIO): + Type = list + @comfytype(io_type="COMFY_MULTITYPED_V3") class MultiType: Type = Any @@ -1279,6 +1287,19 @@ class Color(ComfyTypeIO): def as_dict(self): return super().as_dict() + +@comfytype(io_type="COLORS") +class Colors(ComfyTypeIO): + Type = list[Color.Type] + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[str]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, 
default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [] + + @comfytype(io_type="BOUNDING_BOX") class BoundingBox(ComfyTypeIO): class BoundingBoxDict(TypedDict): @@ -1326,6 +1347,20 @@ class Curve(ComfyTypeIO): return d +@comfytype(io_type="BOUNDING_BOXES") +class BoundingBoxes(ComfyTypeIO): + class BoundingBoxWithMetadata(BoundingBox.BoundingBoxDict): + metadata: dict + Type = list[BoundingBoxWithMetadata] + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[dict]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [] + + @comfytype(io_type="HISTOGRAM") class Histogram(ComfyTypeIO): """A histogram represented as a list of bin counts.""" @@ -2376,6 +2411,8 @@ __all__ = [ "AnyType", "MultiType", "Tracks", + "Dict", + "Array", "Color", # Dynamic Types "MatchType", @@ -2394,6 +2431,8 @@ __all__ = [ "PriceBadgeDepends", "PriceBadge", "BoundingBox", + "BoundingBoxes", + "Colors", "Curve", "Histogram", "Range", diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 47f24586c..2d65d8645 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -163,15 +163,31 @@ class SeedanceVirtualLibraryCreateAssetRequest(BaseModel): asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.") -# Dollars per 1K tokens, keyed by (model_id, has_video_input). +# Dollars per 1K tokens, keyed by (model_id, has_video_input, resolution). 
SEEDANCE2_PRICE_PER_1K_TOKENS = { - ("dreamina-seedance-2-0-260128", False): 0.007, - ("dreamina-seedance-2-0-260128", True): 0.0043, - ("dreamina-seedance-2-0-fast-260128", False): 0.0056, - ("dreamina-seedance-2-0-fast-260128", True): 0.0033, + ("dreamina-seedance-2-0-260128", False, "480p"): 0.007, + ("dreamina-seedance-2-0-260128", True, "480p"): 0.0043, + ("dreamina-seedance-2-0-260128", False, "720p"): 0.007, + ("dreamina-seedance-2-0-260128", True, "720p"): 0.0043, + ("dreamina-seedance-2-0-260128", False, "1080p"): 0.0077, + ("dreamina-seedance-2-0-260128", True, "1080p"): 0.0047, + ("dreamina-seedance-2-0-260128", False, "4k"): 0.004, + ("dreamina-seedance-2-0-260128", True, "4k"): 0.0024, + ("dreamina-seedance-2-0-fast-260128", False, "480p"): 0.0056, + ("dreamina-seedance-2-0-fast-260128", True, "480p"): 0.0033, + ("dreamina-seedance-2-0-fast-260128", False, "720p"): 0.0056, + ("dreamina-seedance-2-0-fast-260128", True, "720p"): 0.0033, + ("dreamina-seedance-2-0-mini", False, "480p"): 0.0035, + ("dreamina-seedance-2-0-mini", True, "480p"): 0.0021, + ("dreamina-seedance-2-0-mini", False, "720p"): 0.0035, + ("dreamina-seedance-2-0-mini", True, "720p"): 0.0021, } +def seedance2_price_per_1k_tokens(model_id: str, has_video_input: bool, resolution: str) -> float | None: + return SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input, resolution)) + + RECOMMENDED_PRESETS = [ ("1024x1024 (1:1)", 1024, 1024), ("864x1152 (3:4)", 864, 1152), @@ -266,6 +282,10 @@ SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = { "480p": {"min": 409_600, "max": 927_408}, "720p": {"min": 409_600, "max": 927_408}, }, + "dreamina-seedance-2-0-mini": { + "480p": {"min": 409_600, "max": 927_408}, + "720p": {"min": 409_600, "max": 927_408}, + }, } # The time in this dictionary are given for 10 seconds duration. 
diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index caaba8f36..7b2543270 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -121,6 +121,7 @@ class GeminiGenerationConfig(BaseModel): topK: int | None = Field(None, ge=1) topP: float | None = Field(None, ge=0.0, le=1.0) thinkingConfig: GeminiThinkingConfig | None = Field(None) + responseModalities: list[str] | None = Field(None) class GeminiImageOutputOptions(BaseModel): diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py index fe0f97cb3..2c98c23b7 100644 --- a/comfy_api_nodes/apis/kling.py +++ b/comfy_api_nodes/apis/kling.py @@ -149,3 +149,59 @@ class MotionControlRequest(BaseModel): character_orientation: str = Field(...) mode: str = Field(..., description="'pro' or 'std'") model_name: str = Field(...) + + +class Kling3TurboSettings(BaseModel): + resolution: str = Field("720p", description="'720p' or '1080p'") + aspect_ratio: str | None = Field(None, description="'16:9'/'9:16'/'1:1'; text-to-video only") + duration: int = Field(5, description="3-15 second") + + +class Kling3TurboText2VideoRequest(BaseModel): + prompt: str = Field(..., description="<=3072 chars; may use multi-shot 'shot n, m, words; ...'") + settings: Kling3TurboSettings | None = Field(None) + + +class Kling3TurboContent(BaseModel): + type: str = Field(..., description="'prompt' or 'first_frame'") + text: str | None = Field(None, description="for type=prompt; <=2500 chars") + url: str | None = Field(None, description="for type=first_frame") + + +class Kling3TurboImage2VideoRequest(BaseModel): + contents: list[Kling3TurboContent] = Field(..., description="prompt + first_frame materials") + settings: Kling3TurboSettings | None = Field(None) + + +class Kling3TurboCreateData(BaseModel): + id: str | None = Field(None, description="Task ID") + status: str | None = Field(None) + message: str | None = Field(None) + + +class Kling3TurboCreateResponse(BaseModel): + code: int 
| None = Field(None) + message: str | None = Field(None) + request_id: str | None = Field(None) + data: Kling3TurboCreateData | None = Field(None) + + +class Kling3TurboOutput(BaseModel): + type: str | None = Field(None, description="'video', 'image', 'audio', ...") + id: str | None = Field(None) + url: str | None = Field(None) + duration: str | None = Field(None) + + +class Kling3TurboTaskData(BaseModel): + id: str | None = Field(None) + status: str | None = Field(None, description="submitted | processing | succeeded | failed") + message: str | None = Field(None) + outputs: list[Kling3TurboOutput] | None = Field(None) + + +class Kling3TurboQueryResponse(BaseModel): + code: int | None = Field(None) + message: str | None = Field(None) + request_id: str | None = Field(None) + data: list[Kling3TurboTaskData] | None = Field(None) diff --git a/comfy_api_nodes/apis/luma.py b/comfy_api_nodes/apis/luma.py index 8c6db2022..2465c3b37 100644 --- a/comfy_api_nodes/apis/luma.py +++ b/comfy_api_nodes/apis/luma.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, confloat class LumaIO: LUMA_REF = "LUMA_REF" LUMA_CONCEPTS = "LUMA_CONCEPTS" + LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME" class LumaReference: @@ -20,13 +21,14 @@ class LumaReference: def create_api_model(self, download_url: str): return LumaImageRef(url=download_url, weight=self.weight) + class LumaReferenceChain: - def __init__(self, first_ref: LumaReference=None): + def __init__(self, first_ref: LumaReference = None): self.refs: list[LumaReference] = [] if first_ref: self.refs.append(first_ref) - def add(self, luma_ref: LumaReference=None): + def add(self, luma_ref: LumaReference = None): self.refs.append(luma_ref) def create_api_model(self, download_urls: list[str], max_refs=4): @@ -124,7 +126,7 @@ def get_luma_concepts(include_none=False): "pull_out", "aerial", "crane_up", - "eye_level" + "eye_level", ] @@ -162,8 +164,8 @@ class LumaVideoModelOutputDuration(str, Enum): class LumaGenerationType(str, Enum): - 
video = 'video' - image = 'image' + video = "video" + image = "image" class LumaState(str, Enum): @@ -174,86 +176,109 @@ class LumaState(str, Enum): class LumaAssets(BaseModel): - video: Optional[str] = Field(None, description='The URL of the video') - image: Optional[str] = Field(None, description='The URL of the image') - progress_video: Optional[str] = Field(None, description='The URL of the progress video') + video: Optional[str] = Field(None, description="The URL of the video") + image: Optional[str] = Field(None, description="The URL of the image") + progress_video: Optional[str] = Field(None, description="The URL of the progress video") class LumaImageRef(BaseModel): """Used for image gen""" - url: str = Field(..., description='The URL of the image reference') - weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + + url: str = Field(..., description="The URL of the image reference") + weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference") class LumaImageReference(BaseModel): """Used for video gen""" - type: Optional[str] = Field('image', description='Input type, defaults to image') - url: str = Field(..., description='The URL of the image') + + type: Optional[str] = Field("image", description="Input type, defaults to image") + url: str = Field(..., description="The URL of the image") class LumaModifyImageRef(BaseModel): - url: str = Field(..., description='The URL of the image reference') - weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') + url: str = Field(..., description="The URL of the image reference") + weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference") class LumaCharacterRef(BaseModel): - identity0: LumaImageIdentity = Field(..., description='The image identity object') + identity0: LumaImageIdentity = Field(..., description="The image identity object") class 
LumaImageIdentity(BaseModel): - images: list[str] = Field(..., description='The URLs of the image identity') + images: list[str] = Field(..., description="The URLs of the image identity") class LumaGenerationReference(BaseModel): - type: str = Field('generation', description='Input type, defaults to generation') - id: str = Field(..., description='The ID of the generation') + type: str = Field("generation", description="Input type, defaults to generation") + id: str = Field(..., description="The ID of the generation") class LumaKeyframes(BaseModel): - frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') - frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='') + frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="") + frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="") class LumaConceptObject(BaseModel): - key: str = Field(..., description='Camera Concept name') + key: str = Field(..., description="Camera Concept name") class LumaImageGenerationRequest(BaseModel): - prompt: str = Field(..., description='The prompt of the generation') - model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation') - aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation') - image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects') - style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects') - character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object') - modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object') + prompt: str = Field(..., description="The prompt of the generation") + model: LumaImageModel = 
Field(LumaImageModel.photon_1, description="The image model used for the generation") + aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9) + image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects") + style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects") + character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object") + modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object") class LumaGenerationRequest(BaseModel): - prompt: str = Field(..., description='The prompt of the generation') - model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation') - duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation') - aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation') - resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation') - loop: Optional[bool] = Field(None, description='Whether to loop the video') - keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation') - concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation') + prompt: str = Field(..., description="The prompt of the generation") + model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation") + duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation") + aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation") + resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation") + loop: Optional[bool] = Field(None, description="Whether to loop the 
video") + keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation") + concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation") class LumaGeneration(BaseModel): - id: str = Field(..., description='The ID of the generation') - generation_type: LumaGenerationType = Field(..., description='Generation type, image or video') - state: LumaState = Field(..., description='The state of the generation') - failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation') - created_at: str = Field(..., description='The date and time when the generation was created') - assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') - model: str = Field(..., description='The model used for the generation') - request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") + id: str = Field(..., description="The ID of the generation") + generation_type: LumaGenerationType = Field(..., description="Generation type, image or video") + state: LumaState = Field(..., description="The state of the generation") + failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation") + created_at: str = Field(..., description="The date and time when the generation was created") + assets: Optional[LumaAssets] = Field(None, description="The assets of the generation") + model: str = Field(..., description="The model used for the generation") + request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...) 
class Luma2ImageRef(BaseModel): url: str | None = None data: str | None = None media_type: str | None = None + generation_id: str | None = Field(None, description="reference a prior generation (extend / source reuse)") + + +class Luma2VideoEdit(BaseModel): + """Edit controls for Ray 3.2 ``video_edit`` generations.""" + + auto_controls: bool | None = Field(None, description="derive a conditioning schedule from the source (recommended)") + strength: str | None = Field(None, description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo") + + +class Luma2VideoOptions(BaseModel): + """Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend).""" + + resolution: str | None = Field(None, description="360p | 540p | 720p | 1080p") + duration: str | None = Field(None, description="5s | 10s") + loop: bool | None = Field(None) + start_frame: Luma2ImageRef | None = Field(None) + end_frame: Luma2ImageRef | None = Field(None) + keyframes: list[Luma2ImageRef] | None = Field(None) + keyframe_indexes: list[int] | None = Field(None) + edit: Luma2VideoEdit | None = Field(None) class Luma2GenerationRequest(BaseModel): @@ -266,6 +291,7 @@ class Luma2GenerationRequest(BaseModel): web_search: bool | None = None image_ref: list[Luma2ImageRef] | None = None source: Luma2ImageRef | None = None + video: Luma2VideoOptions | None = Field(None) class Luma2Generation(BaseModel): @@ -277,3 +303,31 @@ class Luma2Generation(BaseModel): output: list[LumaImageReference] | None = None failure_reason: str | None = None failure_code: str | None = None + + +# --- Ray 3.2 multi-keyframe chain --- + +LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration +LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output + + +class LumaRay32KeyframeItem: + """One guide image anchored at a position on the Ray 3.2 output timeline.""" + + def __init__(self, image: torch.Tensor, mode: str, value: float): + self.image = 
image + self.mode = mode # LUMA_KEYFRAME_MODE_FRACTION | LUMA_KEYFRAME_MODE_SECONDS + self.value = value + + +class LumaRay32KeyframeChain: + def __init__(self): + self.items: list[LumaRay32KeyframeItem] = [] + + def add(self, item: LumaRay32KeyframeItem) -> None: + self.items.append(item) + + def clone(self) -> "LumaRay32KeyframeChain": + c = LumaRay32KeyframeChain() + c.items = list(self.items) + return c diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index c30ddc446..f22415abd 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -15,7 +15,6 @@ from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS_SEEDREAM_4_0, RECOMMENDED_PRESETS_SEEDREAM_4_5, RECOMMENDED_PRESETS_SEEDREAM_5_LITE, - SEEDANCE2_PRICE_PER_1K_TOKENS, SEEDANCE2_REF_VIDEO_PIXEL_LIMITS, VIDEO_TASKS_EXECUTION_TIME, GetAssetResponse, @@ -40,6 +39,7 @@ from comfy_api_nodes.apis.bytedance import ( TaskVideoContentUrl, Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, + seedance2_price_per_1k_tokens, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -89,6 +89,7 @@ BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/cont SEEDANCE_MODELS = { "Seedance 2.0": "dreamina-seedance-2-0-260128", "Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128", + "Seedance 2.0 Mini": "dreamina-seedance-2-0-mini", } DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"} @@ -141,7 +142,7 @@ SEEDANCE2_RATIO_WH = { "9:16": (9, 16), "21:9": (21, 9), } -SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080} +SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080, "4k": 2160} def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]: @@ -377,9 +378,9 @@ async def _seedance_virtual_library_upload_video_asset( return f"asset://{create_resp.asset_id}" -def _seedance2_price_extractor(model_id: str, has_video_input: 
bool): +def _seedance2_price_extractor(model_id: str, has_video_input: bool, resolution: str): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" - rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) + rate = seedance2_price_per_1k_tokens(model_id, has_video_input, resolution) if rate is None: return None @@ -1621,10 +1622,12 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ - IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), + IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p", "4k"])), IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), + IO.DynamicCombo.Option("Seedance 2.0 Mini", _seedance2_text_inputs(["480p", "720p"])), ], - tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", + tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; " + "Mini for the fastest, lowest-cost generation.", ), IO.Int.Input( "seed", @@ -1660,11 +1663,16 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): $rate480 := 10044; $rate720 := 21600; $rate1080 := 48800; + $rate4k := 195200; $m := widgets.model; - $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $rate := $res = "1080p" ? $rate1080 : + $pricePer1K := $res = "4k" ? 0.00572 : + $res = "1080p" ? 0.011011 : + $contains($m, "mini") ? 0.005005 : + $contains($m, "fast") ? 0.008008 : 0.01001; + $rate := $res = "4k" ? $rate4k : + $res = "1080p" ? $rate1080 : $res = "720p" ? 
$rate720 : $rate480; $cost := $dur * $rate * $pricePer1K / 1000; @@ -1703,7 +1711,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), response_model=TaskStatusResponse, status_extractor=lambda r: r.status, - price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), + price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]), poll_interval=9, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) @@ -1724,14 +1732,19 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): options=[ IO.DynamicCombo.Option( "Seedance 2.0", - _seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"), + _seedance2_text_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"), ), IO.DynamicCombo.Option( "Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"), ), + IO.DynamicCombo.Option( + "Seedance 2.0 Mini", + _seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"), + ), ], - tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", + tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; " + "Mini for the fastest, lowest-cost generation.", ), IO.Image.Input( "first_frame", @@ -1791,11 +1804,16 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): $rate480 := 10044; $rate720 := 21600; $rate1080 := 48800; + $rate4k := 195200; $m := widgets.model; - $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $rate := $res = "1080p" ? $rate1080 : + $pricePer1K := $res = "4k" ? 0.00572 : + $res = "1080p" ? 0.011011 : + $contains($m, "mini") ? 0.005005 : + $contains($m, "fast") ? 0.008008 : 0.01001; + $rate := $res = "4k" ? $rate4k : + $res = "1080p" ? $rate1080 : $res = "720p" ? 
$rate720 : $rate480; $cost := $dur * $rate * $pricePer1K / 1000; @@ -1913,7 +1931,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), response_model=TaskStatusResponse, status_extractor=lambda r: r.status, - price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), + price_extractor=_seedance2_price_extractor(model_id, has_video_input=False, resolution=model["resolution"]), poll_interval=9, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) @@ -2010,14 +2028,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode): options=[ IO.DynamicCombo.Option( "Seedance 2.0", - _seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"), + _seedance2_reference_inputs(["480p", "720p", "1080p", "4k"], default_ratio="adaptive"), ), IO.DynamicCombo.Option( "Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"), ), + IO.DynamicCombo.Option( + "Seedance 2.0 Mini", + _seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"), + ), ], - tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", + tooltip="Seedance 2.0 for maximum quality; Fast for speed optimization; " + "Mini for the fastest, lowest-cost generation.", ), IO.Int.Input( "seed", @@ -2056,13 +2079,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode): $rate480 := 10044; $rate720 := 21600; $rate1080 := 48800; + $rate4k := 195200; $m := widgets.model; $hasVideo := $lookup(inputGroups, "model.reference_videos") > 0; - $noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; - $videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149; $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $rate := $res = "1080p" ? $rate1080 : + $noVideoPricePer1K := $res = "4k" ? 0.00572 : + $res = "1080p" ? 0.011011 : + $contains($m, "mini") ? 
0.005005 : + $contains($m, "fast") ? 0.008008 : 0.01001; + $videoPricePer1K := $res = "4k" ? 0.003432 : + $res = "1080p" ? 0.006721 : + $contains($m, "mini") ? 0.003003 : + $contains($m, "fast") ? 0.004719 : 0.006149; + $rate := $res = "4k" ? $rate4k : + $res = "1080p" ? $rate1080 : $res = "720p" ? $rate720 : $rate480; $noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000; @@ -2258,7 +2289,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode): ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), response_model=TaskStatusResponse, status_extractor=lambda r: r.status, - price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input), + price_extractor=_seedance2_price_extractor( + model_id, has_video_input=has_video_input, resolution=model["resolution"] + ), poll_interval=9, ) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3d4be6065..aa992802d 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -5,7 +5,6 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer import base64 import os -from enum import Enum from fnmatch import fnmatch from io import BytesIO from typing import Any, Literal @@ -14,7 +13,7 @@ import torch from typing_extensions import override import folder_paths -from comfy_api.latest import IO, ComfyExtension, Input, Types +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types from comfy_api_nodes.apis.gemini import ( GeminiContent, GeminiFileData, @@ -38,6 +37,7 @@ from comfy_api_nodes.util import ( audio_to_base64_string, bytesio_to_image_tensor, download_url_to_image_tensor, + download_url_to_video_output, get_number_of_images, sync_op, tensor_to_base64_string, @@ -46,6 +46,7 @@ from comfy_api_nodes.util import ( upload_images_to_comfyapi, upload_video_to_comfyapi, validate_string, + 
validate_video_duration, video_to_base64_string, ) @@ -78,15 +79,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( ) -class GeminiImageModel(str, Enum): - """ - Gemini Image Model Names allowed by comfy-api - """ - - gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview" - gemini_2_5_flash_image = "gemini-2.5-flash-image" - - async def create_image_parts( cls: type[IO.ComfyNode], images: Input.Image | list[Input.Image], @@ -239,25 +231,38 @@ async def get_image_from_response(response: GeminiGenerateContentResponse, thoug return torch.cat(image_tensors, dim=0) +async def get_video_from_response( + response: GeminiGenerateContentResponse, cls: type[IO.ComfyNode] | None = None +) -> InputImpl.VideoFromFile: + parts = get_parts_by_type(response, "video/*") + for part in parts: + if part.inlineData and part.inlineData.data: + return InputImpl.VideoFromFile(BytesIO(base64.b64decode(part.inlineData.data))) + if part.fileData and part.fileData.fileUri: + return await download_url_to_video_output(part.fileData.fileUri, cls=cls) + model_message = get_text_from_response(response).strip() + if model_message: + raise ValueError(f"Gemini did not generate a video. Model response: {model_message}") + raise ValueError( + "Gemini did not generate a video. Try rephrasing your prompt, " + "shortening the requested duration, or reducing the number of input images/videos." 
+ ) + + def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None: if not response.modelVersion: return None # Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing - if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"): + output_video_tokens_price = 0.0 + if response.modelVersion == "gemini-2.5-pro": input_tokens_price = 1.25 output_text_tokens_price = 10.0 output_image_tokens_price = 0.0 - elif response.modelVersion in ( - "gemini-2.5-flash-preview-04-17", - "gemini-2.5-flash", - ): + elif response.modelVersion == "gemini-2.5-flash": input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 0.0 - elif response.modelVersion in ( - "gemini-2.5-flash-image-preview", - "gemini-2.5-flash-image", - ): + elif response.modelVersion == "gemini-2.5-flash-image": input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 30.0 @@ -265,18 +270,27 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 0.0 - elif response.modelVersion == "gemini-3.1-flash-lite-preview": + elif response.modelVersion in ("gemini-3.1-flash-lite-preview", "gemini-3.1-flash-lite"): input_tokens_price = 0.25 output_text_tokens_price = 1.50 output_image_tokens_price = 0.0 - elif response.modelVersion == "gemini-3-pro-image-preview": + elif response.modelVersion in ("gemini-3-pro-image-preview", "gemini-3-pro-image"): input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 120.0 - elif response.modelVersion == "gemini-3.1-flash-image-preview": + elif response.modelVersion in ("gemini-3.1-flash-image-preview", "gemini-3.1-flash-image"): input_tokens_price = 0.5 output_text_tokens_price = 3.0 output_image_tokens_price = 60.0 + elif response.modelVersion == "gemini-3.1-flash-lite-image": + input_tokens_price = 0.25 
+ output_text_tokens_price = 1.50 + output_image_tokens_price = 30.0 + elif response.modelVersion == "gemini-omni-flash-preview": + input_tokens_price = 2.145 + output_text_tokens_price = 12.87 + output_image_tokens_price = 0.0 + output_video_tokens_price = 25.025 else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price @@ -284,6 +298,8 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N for i in response.usageMetadata.candidatesTokensDetails: if i.modality == Modality.IMAGE: final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + elif i.modality == Modality.VIDEO: + final_price += output_video_tokens_price * i.tokenCount # for Omni Flash else: final_price += output_text_tokens_price * i.tokenCount if response.usageMetadata.thoughtsTokenCount: @@ -455,8 +471,6 @@ class GeminiNode(IO.ComfyNode): IO.Combo.Input( "model", options=[ - "gemini-2.5-pro-preview-05-06", - "gemini-2.5-flash-preview-04-17", "gemini-2.5-pro", "gemini-2.5-flash", "gemini-3-pro-preview", @@ -904,8 +918,7 @@ class GeminiImage(IO.ComfyNode): ), IO.Combo.Input( "model", - options=GeminiImageModel, - default=GeminiImageModel.gemini_2_5_flash_image, + options=["gemini-2.5-flash-image"], tooltip="The Gemini model to use for generating responses.", ), IO.Int.Input( @@ -1321,7 +1334,7 @@ class GeminiNanoBanana2(IO.ComfyNode): ) -def _nano_banana_2_v2_model_inputs(): +def _nano_banana_2_v2_model_inputs(resolutions: list[str]): return [ IO.Combo.Input( "aspect_ratio", @@ -1348,8 +1361,8 @@ def _nano_banana_2_v2_model_inputs(): ), IO.Combo.Input( "resolution", - options=["1K", "2K", "4K"], - tooltip="Target output resolution. 
For 2K/4K the native Gemini upscaler is used.", + options=resolutions, + tooltip="Target output resolution.", ), IO.Combo.Input( "thinking_level", @@ -1395,7 +1408,11 @@ class GeminiNanoBanana2V2(IO.ComfyNode): options=[ IO.DynamicCombo.Option( "Nano Banana 2 (Gemini 3.1 Flash Image)", - _nano_banana_2_v2_model_inputs(), + _nano_banana_2_v2_model_inputs(resolutions=["1K", "2K", "4K"]), + ), + IO.DynamicCombo.Option( + "Nano Banana 2 Lite", + _nano_banana_2_v2_model_inputs(resolutions=["1K"]), ), ], ), @@ -1464,9 +1481,13 @@ class GeminiNanoBanana2V2(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]), expr=""" ( - $r := $lookup(widgets, "model.resolution"); - $prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154}; - {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} + $contains(widgets.model, "lite") + ? {"type":"usd","usd": 0.034, "format":{"suffix":"/Image","approximate":true}} + : ( + $r := $lookup(widgets, "model.resolution"); + $prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154}; + {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} + ) ) """, ), @@ -1487,6 +1508,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode): model_choice = model["model"] if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)": model_id = "gemini-3.1-flash-image-preview" + elif model_choice == "Nano Banana 2 Lite": + model_id = "gemini-3.1-flash-lite-image" else: model_id = model_choice @@ -1536,6 +1559,149 @@ class GeminiNanoBanana2V2(IO.ComfyNode): ) +OMNI_MAX_IMAGES = 14 +OMNI_MAX_VIDEOS = 3 + +OMNI_MODELS: dict[str, str] = { + "Omni Flash": "gemini-omni-flash-preview", +} + + +def _omni_flash_inputs() -> list[Input]: + """Per-model inputs for the Omni video DynamicCombo (prompt + reference media + sampling).""" + return [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describe the video to generate. 
Specify the length and aspect ratio directly in the " + 'prompt, e.g. "a 6-second clip in 16:9". Length may be 3-10 seconds; the aspect ratio must be ' + "16:9 (landscape) or 9:16 (portrait). The output is 720p, 24 FPS, with audio.", + ), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, OMNI_MAX_IMAGES + 1)], + min=0, + ), + tooltip=f"Optional reference image(s) to guide or animate the video. Up to {OMNI_MAX_IMAGES} images.", + ), + IO.Autogrow.Input( + "videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("video"), + names=[f"video_{i}" for i in range(1, OMNI_MAX_VIDEOS + 1)], + min=0, + ), + tooltip=f"Optional reference video(s) to guide or edit. Up to {OMNI_MAX_VIDEOS} videos, " + f"each up to 10 seconds long.", + ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + tooltip="Controls randomness. Lower is more focused/deterministic, higher is more varied.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=0.95, + min=0.0, + max=1.0, + step=0.01, + tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.", + advanced=True, + ), + ] + + +class GeminiVideoOmni(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiVideoOmni", + display_name="Google Gemini Omni (Video)", + category="partner/video/Gemini", + essentials_category="Video Generation", + description="Generate a video with audio from a text prompt using Google's Gemini Omni Flash model. " + "Optionally provide reference images and/or videos to guide or edit the result. 
Describe the desired " + "length (3-10s) and aspect ratio (16:9 or 9:16) directly in the prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option("Omni Flash", _omni_flash_inputs()), + ], + tooltip="The Gemini video model used to generate the video.", + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr='{"type":"usd","usd":0.146,"format":{"suffix":"/second","approximate":true}}' + ), + ) + + @classmethod + async def execute(cls, model: dict, seed: int) -> IO.NodeOutput: + prompt = model.get("prompt") or "" + validate_string(prompt, strip_whitespace=True, min_length=1) + model_id = OMNI_MODELS[model["model"]] + + images = [t for t in (model.get("images") or {}).values() if t is not None] + videos = [v for v in (model.get("videos") or {}).values() if v is not None] + if sum(get_number_of_images(t) for t in images) > OMNI_MAX_IMAGES: + raise ValueError(f"The current maximum number of supported images is {OMNI_MAX_IMAGES}.") + if len(videos) > OMNI_MAX_VIDEOS: + raise ValueError(f"The current maximum number of supported videos is {OMNI_MAX_VIDEOS}.") + for video in videos: + validate_video_duration(video, max_duration=10) + + parts: list[GeminiPart] = [] + if images or videos: + parts.extend(await build_gemini_media_parts(cls, images, [], videos)) + parts.append(GeminiPart(text=prompt)) + response = await sync_op( + cls, + ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"), + data=GeminiGenerateContentRequest( + contents=[GeminiContent(role=GeminiRole.user, parts=parts)], + generationConfig=GeminiGenerationConfig( + 
responseModalities=["TEXT", "VIDEO"], + temperature=model.get("temperature", 1.0), + topP=model.get("top_p", 0.95), + ), + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput( + await get_video_from_response(response, cls=cls), + get_text_from_response(response), + ) + + class GeminiExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1546,6 +1712,7 @@ class GeminiExtension(ComfyExtension): GeminiImage2, GeminiNanoBanana2, GeminiNanoBanana2V2, + GeminiVideoOmni, GeminiInputFiles, ] diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index 2ae529813..dc484536e 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -30,7 +30,7 @@ from comfy_api_nodes.util import ( _GROK_VIDEO_MODEL_API_IDS = { - "grok-imagine-video-1.5": "grok-imagine-video-1.5-preview", + "grok-imagine-video-1.5": "grok-imagine-video-1.5", } @@ -521,8 +521,8 @@ class GrokVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=["480p", "720p"], - tooltip="The resolution of the output video.", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video. 1080p is only available for grok-imagine-video-1.5.", ), IO.Combo.Input( "aspect_ratio", @@ -570,11 +570,12 @@ class GrokVideoNode(IO.ComfyNode): ( $is15 := $contains(widgets.model, "1.5"); $rate := $is15 - ? (widgets.resolution = "720p" ? 0.2002 : 0.1144) + ? (widgets.resolution = "1080p" ? 0.25 : (widgets.resolution = "720p" ? 0.14 : 0.08)) : (widgets.resolution = "720p" ? 0.07 : 0.05); - $imgCost := $is15 ? 0.0143 : 0.002; + $imgCost := $is15 ? 0.01 : 0.002; $base := $rate * widgets.duration; - {"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base} + $total := inputs.image.connected ? $base + $imgCost : $base; + {"type":"usd","usd": $is15 ? 
$total * 1.43 : $total} ) """, ), @@ -593,6 +594,8 @@ class GrokVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: if image is None and model == "grok-imagine-video-1.5": raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.") + if resolution == "1080p" and model != "grok-imagine-video-1.5": + raise ValueError(f"1080p resolution is only available for grok-imagine-video-1.5, not '{model}'.") image_url = None if image is not None: if get_number_of_images(image) != 1: diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index c81d3503d..b27de2549 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -60,6 +60,12 @@ from comfy_api_nodes.apis.kling import ( OmniProImageRequest, OmniProReferences2VideoRequest, OmniProText2VideoRequest, + Kling3TurboSettings, + Kling3TurboText2VideoRequest, + Kling3TurboContent, + Kling3TurboImage2VideoRequest, + Kling3TurboCreateResponse, + Kling3TurboQueryResponse, TaskStatusResponse, TextToVideoWithAudioRequest, ) @@ -2847,6 +2853,67 @@ class MotionControl(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +def build_turbo_shot_prompt(multi_prompt: list[MultiPromptEntry]) -> str: + """Render storyboard entries into the Turbo multi-shot prompt 'shot n, m, words; ...'.""" + return "; ".join(f"shot {i}, {int(e.duration)}, {e.prompt}" for i, e in enumerate(multi_prompt, 1)) + ";" + + +def _turbo_video_url(response: Kling3TurboQueryResponse) -> str: + """Extract the result video URL from a /tasks response (data[].outputs[] where type == 'video').""" + task = response.data[0] if response.data else None + if task and task.outputs: + for output in task.outputs: + if output.type == "video" and output.url: + return output.url + raise RuntimeError(f"Kling 3.0 Turbo task finished without a video output: {response.model_dump()}") + + +async def execute_kling_turbo( + cls: type[IO.ComfyNode], 
+ *, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + start_frame: torch.Tensor | None, +) -> IO.NodeOutput: + """Create + poll a Kling 3.0 Turbo task. Image-to-video when start_frame is given, else text-to-video.""" + if start_frame is not None: + validate_image_dimensions(start_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) + contents = [Kling3TurboContent(type="first_frame", url=tensor_to_base64_string(start_frame))] + if prompt: + contents.insert(0, Kling3TurboContent(type="prompt", text=prompt)) + create = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/image-to-video/kling-3.0-turbo", method="POST"), + response_model=Kling3TurboCreateResponse, + data=Kling3TurboImage2VideoRequest( + contents=contents, + settings=Kling3TurboSettings(resolution=resolution, duration=duration), # i2v: no aspect_ratio + ), + ) + else: + create = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/text-to-video/kling-3.0-turbo", method="POST"), + response_model=Kling3TurboCreateResponse, + data=Kling3TurboText2VideoRequest( + prompt=prompt, + settings=Kling3TurboSettings(resolution=resolution, aspect_ratio=aspect_ratio, duration=duration), + ), + ) + if not (create.data and create.data.id): + raise RuntimeError(f"Kling 3.0 Turbo create failed. 
Code: {create.code}, Message: {create.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/kling/tasks", query_params={"task_ids": create.data.id}), + response_model=Kling3TurboQueryResponse, + status_extractor=lambda r: (r.data[0].status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(_turbo_video_url(final_response))) + + class KlingVideoNode(IO.ComfyNode): @classmethod @@ -2884,7 +2951,11 @@ class KlingVideoNode(IO.ComfyNode): ], tooltip="Generate a series of video segments with individual prompts and durations.", ), - IO.Boolean.Input("generate_audio", default=True), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="'kling-3.0-turbo' always generates native audio, so the audio toggle is ignored.", + ), IO.DynamicCombo.Input( "model", options=[ @@ -2899,6 +2970,17 @@ class KlingVideoNode(IO.ComfyNode): ), ], ), + IO.DynamicCombo.Option( + "kling-3.0-turbo", + [ + IO.Combo.Input("resolution", options=["1080p", "720p"], default="720p"), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1"], + tooltip="Ignored in image-to-video mode.", + ), + ], + ), ], tooltip="Model and generation settings.", ), @@ -2930,6 +3012,7 @@ class KlingVideoNode(IO.ComfyNode): price_badge=IO.PriceBadge( depends_on=IO.PriceBadgeDepends( widgets=[ + "model", "model.resolution", "generate_audio", "multi_shot", @@ -2944,14 +3027,7 @@ class KlingVideoNode(IO.ComfyNode): ), expr=""" ( - $rates := { - "4k": {"off": 0.42, "on": 0.42}, - "1080p": {"off": 0.112, "on": 0.168}, - "720p": {"off": 0.084, "on": 0.126} - }; $res := $lookup(widgets, "model.resolution"); - $audio := widgets.generate_audio ? "on" : "off"; - $rate := $lookup($lookup($rates, $res), $audio); $ms := widgets.multi_shot; $isSb := $ms != "disabled"; $n := $isSb ? $number($substring($ms, 0, 1)) : 0; @@ -2962,7 +3038,18 @@ class KlingVideoNode(IO.ComfyNode): $d5 := $n >= 5 ? 
$lookup(widgets, "multi_shot.storyboard_5_duration") : 0; $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0; $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration"); - {"type":"usd","usd": $rate * $dur} + widgets.model = "kling-3.0-turbo" + ? {"type":"usd","usd": ($res = "1080p" ? 0.14 : 0.112) * $dur} + : ( + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; + $audio := widgets.generate_audio ? "on" : "off"; + $rate := $lookup($lookup($rates, $res), $audio); + {"type":"usd","usd": $rate * $dur} + ) ) """, ), @@ -3015,6 +3102,17 @@ class KlingVideoNode(IO.ComfyNode): duration = multi_shot["duration"] validate_string(multi_shot["prompt"], min_length=1, max_length=2500) + if model["model"] == "kling-3.0-turbo": + turbo_prompt = build_turbo_shot_prompt(multi_prompt_list) if custom_multi_shot else multi_shot["prompt"] + return await execute_kling_turbo( + cls, + prompt=turbo_prompt, + resolution=model["resolution"], + aspect_ratio=model["aspect_ratio"], + duration=duration, + start_frame=start_frame, + ) + if start_frame is not None: validate_image_dimensions(start_frame, min_width=300, min_height=300) validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 0d31ac77e..cdfa32d8b 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -3,9 +3,13 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.luma import ( + LUMA_KEYFRAME_MODE_FRACTION, + LUMA_KEYFRAME_MODE_SECONDS, Luma2Generation, Luma2GenerationRequest, Luma2ImageRef, + Luma2VideoEdit, + Luma2VideoOptions, LumaAspectRatio, LumaCharacterRef, LumaConceptChain, @@ -18,6 +22,8 @@ from comfy_api_nodes.apis.luma import ( LumaIO, LumaKeyframes, LumaModifyImageRef, + LumaRay32KeyframeChain, + 
LumaRay32KeyframeItem, LumaReference, LumaReferenceChain, LumaVideoModel, @@ -33,6 +39,7 @@ from comfy_api_nodes.util import ( sync_op, upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_string, ) @@ -692,7 +699,10 @@ async def _luma2_upload_image_refs( async def _luma2_submit_and_poll( cls: type[IO.ComfyNode], request: Luma2GenerationRequest, -) -> Input.Image: + *, + estimated_duration: int | None = None, +) -> Luma2Generation: + """Submit a Luma Agents generation and poll until done; returns the completed generation.""" initial = await sync_op( cls, ApiEndpoint(path="/proxy/luma_2/generations", method="POST"), @@ -700,21 +710,21 @@ async def _luma2_submit_and_poll( data=request, ) if not initial.id: - raise RuntimeError("Luma 2 API did not return a generation id.") + raise RuntimeError("Luma API did not return a generation id.") final = await poll_op( cls, ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"), response_model=Luma2Generation, status_extractor=lambda r: r.state, progress_extractor=lambda r: None, + estimated_duration=estimated_duration, ) - if not final.output: + if not final.output or not final.output[0].url: msg = final.failure_reason or "no output returned" - raise RuntimeError(f"Luma 2 generation failed: {msg}") - url = final.output[0].url - if not url: - raise RuntimeError("Luma 2 generation completed without an output URL.") - return await download_url_to_image_tensor(url) + if final.failure_code: + msg = f"{msg} [{final.failure_code}]" + raise RuntimeError(f"Luma generation failed: {msg}") + return final class LumaImageNode(IO.ComfyNode): @@ -843,7 +853,8 @@ class LumaImageNode(IO.ComfyNode): web_search=model["web_search"], image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9), ) - return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + final = await _luma2_submit_and_poll(cls, request) + return IO.NodeOutput(await 
download_url_to_image_tensor(final.output[0].url)) class LumaImageEditNode(IO.ComfyNode): @@ -929,7 +940,533 @@ class LumaImageEditNode(IO.ComfyNode): web_search=model["web_search"], image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8), ) - return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + final = await _luma2_submit_and_poll(cls, request) + return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url)) + + +_BADGE_RAY32_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), + expr=""" + ( + $p := { + "360p": {"5s": 0.06, "10s": 0.18}, + "540p": {"5s": 0.15, "10s": 0.45}, + "720p": {"5s": 0.3, "10s": 0.9}, + "1080p": {"5s": 1.2, "10s": 3.6} + }; + {"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)} + ) + """, +) + +_BADGE_RAY32_VIDEO_5S = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2}; + {"type": "usd", "usd": $lookup($p, widgets.resolution)} + ) + """, +) + +_BADGE_RAY32_EDIT = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := { + "360p": {"min": 0.54, "max": 1.08}, + "540p": {"min": 0.72, "max": 1.44}, + "720p": {"min": 1.08, "max": 2.16}, + "1080p": {"min": 2.16, "max": 4.32} + }; + $r := $lookup($p, widgets.resolution); + {"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}} + ) + """, +) + +_BADGE_RAY32_REFRAME = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36}; + {"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}} + ) + """, +) + + +def _ray32_seed_input() -> IO.Input: + return IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + 
tooltip="Seed to determine if node should re-run; results are nondeterministic regardless of seed.", + ) + + +async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput: + """Run a ray-3.2 generation and return (video, generation_id).""" + final = await _luma2_submit_and_poll(cls, request, estimated_duration=120) + video = await download_url_to_video_output(final.output[0].url) + return IO.NodeOutput(video, final.id or "") + + +class LumaRay32TextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32TextToVideoNode", + display_name="Luma Ray 3.2 Text to Video", + category="partner/video/Luma", + description="Generate a video from a text prompt using Luma's Ray 3.2 model.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input("duration", options=["5s", "10s"]), + IO.Boolean.Input( + "loop", + default=False, + tooltip="Make the video loop seamlessly. 
Only available with 5s duration.", + ), + _ray32_seed_input(), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO, + ) + + @classmethod + async def execute( + cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + if loop and duration == "10s": + raise ValueError("Looping is only available with 5s duration on Ray 3.2.") + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video", + aspect_ratio=aspect_ratio, + video=Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop or None), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32ImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32ImageToVideoNode", + display_name="Luma Ray 3.2 Image to Video", + category="partner/video/Luma", + description="Generate a video from a start and/or end frame using Luma's Ray 3.2 model. " + "Image-anchored generations are always 5 seconds.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Boolean.Input( + "loop", + default=False, + tooltip="Make the video loop seamlessly. 
Not available when an end_frame is set.", + ), + _ray32_seed_input(), + IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."), + IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO_5S, + ) + + @classmethod + async def execute( + cls, + prompt: str, + resolution: str, + loop: bool, + seed: int, + start_frame: torch.Tensor | None = None, + end_frame: torch.Tensor | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + if start_frame is None and end_frame is None: + raise ValueError("Provide at least one of start_frame / end_frame.") + if loop and end_frame is not None: + raise ValueError("Looping is not available when an end_frame is set.") + video = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop or None) + if start_frame is not None: + url = await upload_image_to_comfyapi(cls, start_frame, mime_type="image/png") + video.start_frame = Luma2ImageRef(url=url) + if end_frame is not None: + url = await upload_image_to_comfyapi(cls, end_frame, mime_type="image/png") + video.end_frame = Luma2ImageRef(url=url) + request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video) + return await _ray32_generate(cls, request) + + +class LumaRay32KeyframeNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32KeyframeNode", + display_name="Luma Ray 3.2 Keyframe", + category="partner/video/Luma", + description="Anchor a guide image to a position on the Ray 3.2 output video timeline. 
Connect this to " + "the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the " + "optional 'keyframes' input below.", + inputs=[ + IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video."), + IO.DynamicCombo.Input( + "position", + options=[ + IO.DynamicCombo.Option( + "Fraction of duration (0.0-1.0)", + [ + IO.Float.Input( + "fraction", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Where in the output video this image applies " "(0.0 = start, 1.0 = end).", + ), + ], + ), + IO.DynamicCombo.Option( + "Absolute time (seconds)", + [ + IO.Float.Input( + "seconds", + default=0.0, + min=0.0, + max=10.0, + step=0.1, + display_mode=IO.NumberDisplay.number, + tooltip="Time in seconds from the start of the output video where this " + "image applies.", + ), + ], + ), + ], + tooltip="How to place this image on the output video's timeline.", + ), + IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input( + "keyframes", + optional=True, + tooltip="Optional earlier keyframes to chain with this one.", + ), + ], + outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")], + ) + + @classmethod + def execute( + cls, + image: torch.Tensor, + position: dict, + keyframes: LumaRay32KeyframeChain | None = None, + ) -> IO.NodeOutput: + chain = keyframes.clone() if keyframes is not None else LumaRay32KeyframeChain() + if position["position"] == "Absolute time (seconds)": + mode, value = LUMA_KEYFRAME_MODE_SECONDS, float(position["seconds"]) + else: + mode, value = LUMA_KEYFRAME_MODE_FRACTION, float(position["fraction"]) + chain.add(LumaRay32KeyframeItem(image=image, mode=mode, value=value)) + return IO.NodeOutput(chain) + + +class LumaRay32KeyframesToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32KeyframesToVideoNode", + display_name="Luma Ray 3.2 Keyframes to Video", + 
category="partner/video/Luma", + description="Generate a video that interpolates through a sequence of guide images, each anchored to a " + "position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes " + "(at least 2).", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input("duration", options=["5s", "10s"]), + _ray32_seed_input(), + IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input( + "keyframes", + tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).", + ), + ], + outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO, + ) + + @classmethod + async def execute( + cls, + prompt: str, + resolution: str, + duration: str, + seed: int, + keyframes: LumaRay32KeyframeChain | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + items = keyframes.items if keyframes is not None else [] + if len(items) < 2: + raise ValueError( + "Connect at least 2 Luma Ray 3.2 Keyframe nodes " + "(use Luma Ray 3.2 Image to Video for a single start/end frame)." 
+ ) + if len(items) > 64: + raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {len(items)}.") + maxframe = 120 if duration == "5s" else 240 + duration_seconds = maxframe / 24 # 5.0 or 10.0 + # Resolve each keyframe to an output-frame index, then order by position + # (so the user can chain keyframes in any order — the position is what places them) + placed: list[tuple[int, torch.Tensor]] = [] + for item in items: + if item.mode == LUMA_KEYFRAME_MODE_SECONDS: + if item.value > duration_seconds: + raise ValueError( + f"Keyframe position {item.value:g}s is past the end of the {duration} video; " + f"use 0-{duration_seconds:g}s (or switch the keyframe to fraction mode)." + ) + idx = round(item.value * 24) + else: + idx = round(item.value * maxframe) + placed.append((max(0, min(maxframe, idx)), item.image)) + placed.sort(key=lambda p: p[0]) + indexes = [idx for idx, _ in placed] + for a, b in zip(indexes, indexes[1:]): + if a == b: + raise ValueError( + f"Two keyframes resolve to the same output frame ({a}) for a {duration} video " + f"(valid range 0-{maxframe}); give each keyframe a distinct position." + ) + refs: list[Luma2ImageRef] = [] + for _, image in placed: + url = await upload_image_to_comfyapi(cls, image, mime_type="image/png") + refs.append(Luma2ImageRef(url=url)) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video", + video=Luma2VideoOptions(resolution=resolution, duration=duration, keyframes=refs, keyframe_indexes=indexes), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32VideoEditNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32VideoEditNode", + display_name="Luma Ray 3.2 Video Edit", + category="partner/video/Luma", + description="Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add " + "or remove elements) while keeping the original motion. 
Source video up to 18 seconds; the edited " + "video keeps the source's length.", + inputs=[ + IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."), + IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + IO.Combo.Input( + "strength", + options=[ + "auto", + "adhere_1", + "adhere_2", + "adhere_3", + "flex_1", + "flex_2", + "flex_3", + "reimagine_1", + "reimagine_2", + "reimagine_3", + ], + default="auto", + tooltip="How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; " + "adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most.", + ), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_EDIT, + ) + + @classmethod + async def execute( + cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000) + try: + duration = "5s" if video.get_duration() <= 5.0 else "10s" + except Exception: + duration = "10s" + source_url = await upload_video_to_comfyapi(cls, video, max_duration=18) + edit = Luma2VideoEdit(auto_controls=True) if strength == "auto" else Luma2VideoEdit(strength=strength) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video_edit", + source=Luma2ImageRef(url=source_url, media_type="video/mp4"), + video=Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32VideoReframeNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32VideoReframeNode", + 
display_name="Luma Ray 3.2 Video Reframe", + category="partner/video/Luma", + description="Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly " + "exposed canvas areas. Source video up to 30 seconds. Billed per second of output.", + inputs=[ + IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describes how the newly exposed canvas areas should be filled.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]), + IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_REFRAME, + ) + + @classmethod + async def execute( + cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000) + if resolution == "1080p" and aspect_ratio in {"9:16", "3:4"}: + raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.") + source_url = await upload_video_to_comfyapi(cls, video, max_duration=30) + request = Luma2GenerationRequest( + prompt=prompt, + model="ray-3.2", + type="video_reframe", + aspect_ratio=aspect_ratio, + source=Luma2ImageRef(url=source_url, media_type="video/mp4"), + video=Luma2VideoOptions(resolution=resolution), + ) + return await _ray32_generate(cls, request) + + +class LumaRay32ExtendVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaRay32ExtendVideoNode", + display_name="Luma Ray 3.2 Extend Video", + category="partner/video/Luma", + description="Extend a previous Ray 3.2 
generation forward (continue after it) or backward (lead-in " + "before it). Connect the generation_id output of a prior Luma Ray 3.2 node." + " Extensions are always 5 seconds.", + inputs=[ + IO.String.Input( + "source_generation_id", + default="", + tooltip="generation_id of the prior Ray 3.2 video to extend." + " Connect the generation_id output of another Luma Ray 3.2 node.", + ), + IO.DynamicCombo.Input( + "direction", + options=[ + IO.DynamicCombo.Option( + "Forward (continue after)", + [ + IO.Boolean.Input( + "loop", + default=False, + tooltip="Loop the extended video seamlessly (forward extend only).", + ), + ], + ), + IO.DynamicCombo.Option("Backward (lead-in before)", []), + ], + tooltip="Forward continues after the prior clip; backward is prepended before it.", + ), + IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."), + IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"), + _ray32_seed_input(), + ], + outputs=[ + IO.Video.Output(), + IO.String.Output(display_name="generation_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=_BADGE_RAY32_VIDEO_5S, + ) + + @classmethod + async def execute( + cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000) + gen_id = (source_generation_id or "").strip() + if not gen_id: + raise ValueError( + "source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)." 
+ ) + video = Luma2VideoOptions(resolution=resolution, duration="5s") + ref = Luma2ImageRef(generation_id=gen_id) + if direction["direction"] == "Forward (continue after)": + video.start_frame = ref + if direction.get("loop"): + video.loop = True + else: + video.end_frame = ref + request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video) + return await _ray32_generate(cls, request) class LumaExtension(ComfyExtension): @@ -944,6 +1481,13 @@ class LumaExtension(ComfyExtension): LumaConceptsNode, LumaImageNode, LumaImageEditNode, + LumaRay32TextToVideoNode, + LumaRay32ImageToVideoNode, + LumaRay32KeyframeNode, + LumaRay32KeyframesToVideoNode, + LumaRay32VideoEditNode, + LumaRay32VideoReframeNode, + LumaRay32ExtendVideoNode, ] diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index b7b97d70f..1782739fd 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -48,10 +48,13 @@ from comfy_api_nodes.util import ( upload_image_to_comfyapi, upload_video_to_comfyapi, validate_audio_duration, + validate_image_aspect_ratio, + validate_image_dimensions, validate_string, validate_video_duration, ) + RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)") @@ -1657,6 +1660,44 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ + IO.DynamicCombo.Option( + "happyhorse-1.1-t2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. 
" + "Supports English and Chinese.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Combo.Input( + "ratio", + options=[ + "16:9", + "9:16", + "1:1", + "4:3", + "3:4", + "21:9", + "9:21", + "5:4", + "4:5", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + ], + ), IO.DynamicCombo.Option( "happyhorse-1.0-t2v", [ @@ -1719,7 +1760,9 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): ( $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $ppsTable := $contains(widgets.model, "1.1") + ? { "720p": 0.2002, "1080p": 0.2574 } + : { "720p": 0.14, "1080p": 0.24 }; $pps := $lookup($ppsTable, $res); { "type": "usd", "usd": $pps * $dur } ) @@ -1781,6 +1824,30 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ + IO.DynamicCombo.Option( + "happyhorse-1.1-i2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. " + "Supports English and Chinese.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + ], + ), IO.DynamicCombo.Option( "happyhorse-1.0-i2v", [ @@ -1843,7 +1910,9 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): ( $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $ppsTable := $contains(widgets.model, "1.1") + ? 
{ "720p": 0.2002, "1080p": 0.2574 } + : { "720p": 0.14, "1080p": 0.24 }; $pps := $lookup($ppsTable, $res); { "type": "usd", "usd": $pps * $dur } ) @@ -1859,6 +1928,8 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): seed: int, watermark: bool, ): + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1), strict=False) media = [ Wan27MediaItem( type="first_frame", @@ -2053,6 +2124,62 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ + IO.DynamicCombo.Option( + "happyhorse-1.1-r2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the video. Use identifiers such as 'character1' and " + "'character2' to refer to the reference characters.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Combo.Input( + "ratio", + options=[ + "16:9", + "9:16", + "1:1", + "4:3", + "3:4", + "21:9", + "9:21", + "5:4", + "4:5", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_image"), + names=[ + "image1", + "image2", + "image3", + "image4", + "image5", + "image6", + "image7", + "image8", + "image9", + ], + min=1, + ), + ), + ], + ), IO.DynamicCombo.Option( "happyhorse-1.0-r2v", [ @@ -2133,7 +2260,9 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): ( $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $ppsTable := $contains(widgets.model, "1.1") + ? 
{ "720p": 0.2002, "1080p": 0.2574 } + : { "720p": 0.14, "1080p": 0.24 }; $pps := $lookup($ppsTable, $res); { "type": "usd", "usd": $pps * $dur } ) @@ -2149,8 +2278,11 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): watermark: bool, ): validate_string(model["prompt"], strip_whitespace=False, min_length=1) - media = [] reference_images = model.get("reference_images", {}) + for key in reference_images: + validate_image_dimensions(reference_images[key], min_width=400, min_height=400) + validate_image_aspect_ratio(reference_images[key], (1, 2.5), (2.5, 1), strict=False) + media = [] for key in reference_images: media.append( Wan27MediaItem( @@ -2159,7 +2291,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): ) ) if not media: - raise ValueError("At least one reference reference image must be provided.") + raise ValueError("At least one reference image must be provided.") initial_response = await sync_op( cls, diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 83cf7b001..6b8121cab 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -4,6 +4,8 @@ import os import re import time from collections.abc import Callable +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime from io import BytesIO from yarl import URL @@ -91,6 +93,32 @@ async def sleep_with_interrupt( await asyncio.sleep(min(1.0, end - now)) +def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float: + """Delay before the next retry, honoring a server ``Retry-After`` header.""" + + seconds: float | None = None + if value is not None: + value = value.strip() + if value.isascii() and value.isdigit(): + # delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float() + # an all-digit string always converts (huge values become inf, never raising). + seconds = float(value) + elif value: + # HTTP-date form. 
parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets + try: + parsed = parsedate_to_datetime(value) + except (TypeError, ValueError, OverflowError): + parsed = None + if parsed is not None: + if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC + parsed = parsed.replace(tzinfo=timezone.utc) + delta = (parsed - datetime.now(timezone.utc)).total_seconds() + seconds = delta if delta > 0 else 0.0 + if seconds is None: + return fallback + return min(seconds, max_wait) + + def mimetype_to_extension(mime_type: str) -> str: """Converts a MIME type to a file extension.""" return mime_type.split("/")[-1].lower() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index adcde7bcb..66aab17f8 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -21,6 +21,7 @@ from server import PromptServer from . import request_logger from ._helpers import ( + _retry_after_wait, default_base_url, get_comfy_api_headers, get_node_id, @@ -82,6 +83,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately +_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"] @@ -747,6 +749,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): should_retry = True if should_retry: + wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT) logging.warning( "HTTP %s %s -> %s. 
Waiting %.2fs (%s).", method, diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 20ebae155..fa3ab0faf 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -4,11 +4,22 @@ Provides normalization and helper functions for job status tracking. """ import uuid -from typing import Optional +from typing import Callable, Optional from comfy_api.internal import prune_dict +# Result of classifying a job for cancellation. +# 'running' -> job is currently executing (interrupt it) +# 'pending' -> job is queued but not started (dequeue it) +# 'terminal' -> job already finished (present in history); cancel is a no-op +# 'unknown' -> job id is not present anywhere +CANCEL_RUNNING = 'running' +CANCEL_PENDING = 'pending' +CANCEL_TERMINAL = 'terminal' +CANCEL_UNKNOWN = 'unknown' + + class JobStatus: """Job status constants.""" PENDING = 'pending' @@ -407,3 +418,71 @@ def get_all_jobs( jobs = jobs[:limit] return (jobs, total_count) + + +def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str: + """Classify a job id for cancellation. + + Returns one of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL, CANCEL_UNKNOWN. + + Queue items are tuples whose second element (index 1) is the prompt_id. + History is a dict keyed by prompt_id, so a job present there has already + finished and cancelling it is a no-op. + """ + for item in running: + if item[1] == prompt_id: + return CANCEL_RUNNING + for item in queued: + if item[1] == prompt_id: + return CANCEL_PENDING + if prompt_id in history: + return CANCEL_TERMINAL + return CANCEL_UNKNOWN + + +def cancel_job( + prompt_id: str, + running: list, + queued: list, + history: dict, + interrupt: Callable[[str], bool], + dequeue: Callable[[str], bool], +) -> str: + """Cancel a single job by id, regardless of state. 
+ + Maps the cancel onto the runtime's existing mechanics: + - a running job is interrupted via ``interrupt`` + - a pending job is removed from the queue via ``dequeue`` + - a job that already finished (terminal) is a no-op + - an unknown id is a no-op (callers that need fail-fast behaviour should + validate ids up front with ``classify_job_for_cancel``) + + Both ``interrupt`` and ``dequeue`` take the prompt id and return whether + they acted on a job that was *actually* in that state, so the value returned + here reflects what truly happened rather than the (possibly stale) + classification. This matters around the narrow TOCTOU windows where a job + changes state between the caller's snapshot and the action: + + - a job classified RUNNING may have finished before ``interrupt`` fires: + ``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op). + - a job classified PENDING may have started executing before ``dequeue`` + fires: ``dequeue`` returns False, ``interrupt`` then catches the now- + running job and this returns CANCEL_RUNNING. If it had simply finished + instead, both return False and this returns CANCEL_UNKNOWN. + + ``interrupt`` must be atomic — interrupt the job only if it is still the one + running — so a cancel can never land on an unrelated prompt that started in + the meantime (see ``execution.PromptQueue.interrupt_if_running``). + """ + classification = classify_job_for_cancel(prompt_id, running, queued, history) + if classification == CANCEL_RUNNING: + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + if classification == CANCEL_PENDING: + if dequeue(prompt_id): + return CANCEL_PENDING + # Left the pending queue between classification and dequeue: if it + # started executing, interrupt the now-running job; otherwise it has + # already finished and the cancel is a genuine no-op. + return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN + # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops. 
+ return classification diff --git a/comfy_extras/color_util.py b/comfy_extras/color_util.py new file mode 100644 index 000000000..d50795ae3 --- /dev/null +++ b/comfy_extras/color_util.py @@ -0,0 +1,23 @@ +def hex_to_rgb(value: str) -> tuple[int, int, int]: + h = value.lstrip("#") + if len(h) != 6: + return (255, 255, 255) + try: + return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)) + except ValueError: + return (255, 255, 255) + + +def readable_color(rgb: tuple[int, int, int]) -> tuple[int, int, int]: + r, g, b = rgb + lum = 0.299 * r + 0.587 * g + 0.114 * b + if lum >= 130: + return (r, g, b) + t = (130 - lum) / (255 - lum) + return (round(r + (255 - r) * t), round(g + (255 - g) * t), round(b + (255 - b) * t)) + + +def normalize_palette(colors) -> list[str]: + if isinstance(colors, dict): + colors = colors.values() + return [c.upper() for c in colors if isinstance(c, str) and c] diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 77f124e28..6adcc95fa 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode): return IO.Schema( node_id="SaveAudio", search_aliases=["export flac"], - display_name="Save Audio (FLAC) (Deprecated)", + display_name="Save Audio (FLAC) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -166,8 +166,9 @@ class SaveAudio(IO.ComfyNode): IO.String.Input("filename_prefix", default="audio/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -175,11 +176,10 @@ class SaveAudio(IO.ComfyNode): if audio is None: raise ValueError("SaveAudio: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) ) - save_flac = execute # TODO: remove - class 
SaveAudioMP3(IO.ComfyNode): @classmethod @@ -187,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode): return IO.Schema( node_id="SaveAudioMP3", search_aliases=["export mp3"], - display_name="Save Audio (MP3) (Deprecated)", + display_name="Save Audio (MP3) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -196,8 +196,9 @@ class SaveAudioMP3(IO.ComfyNode): IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -205,13 +206,12 @@ class SaveAudioMP3(IO.ComfyNode): if audio is None: raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui( audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality ) ) - save_mp3 = execute # TODO: remove - class SaveAudioOpus(IO.ComfyNode): @classmethod @@ -219,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode): return IO.Schema( node_id="SaveAudioOpus", search_aliases=["export opus"], - display_name="Save Audio (Opus) (Deprecated)", + display_name="Save Audio (Opus) (DEPRECATED)", category="audio", inputs=[ IO.Audio.Input("audio"), @@ -227,8 +227,9 @@ class SaveAudioOpus(IO.ComfyNode): IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -236,13 +237,12 @@ class SaveAudioOpus(IO.ComfyNode): if audio is None: raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui( audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality ) ) - save_opus = execute # TODO: remove - class 
SaveAudioAdvanced(IO.ComfyNode): @classmethod @@ -258,10 +258,7 @@ class SaveAudioAdvanced(IO.ComfyNode): IO.String.Input( "filename_prefix", default="audio/ComfyUI", - tooltip=( - "The prefix for the file to save. May include formatting tokens " - "such as %date:yyyy-MM-dd%." - ), + tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd%."), ), IO.DynamicCombo.Input( "format", @@ -279,6 +276,7 @@ class SaveAudioAdvanced(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Audio.Output("audio")], ) @classmethod @@ -289,7 +287,7 @@ class SaveAudioAdvanced(IO.ComfyNode): ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality) else: ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format) - return IO.NodeOutput(ui=ui) + return IO.NodeOutput(audio, ui=ui) class PreviewAudio(IO.ComfyNode): @@ -305,13 +303,14 @@ class PreviewAudio(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod def execute(cls, audio) -> IO.NodeOutput: if audio is None: raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).") - return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + return IO.NodeOutput(audio, ui=UI.PreviewAudio(audio, cls=cls)) save_flac = execute # TODO: remove diff --git a/comfy_extras/nodes_boogu.py b/comfy_extras/nodes_boogu.py new file mode 100644 index 000000000..f3951c290 --- /dev/null +++ b/comfy_extras/nodes_boogu.py @@ -0,0 +1,97 @@ +import math + +import node_helpers +import comfy.utils +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class TextEncodeBooguEdit(io.ComfyNode): + """Boogu-Image Edit conditioning. 
+ + The edit image is used twice, matching the reference pipeline: + - Qwen3-VL vision tokens (instruction understanding) -> positive only + - VAE reference latent (image identity) -> positive and negative + The ref latent is in both conds so it cancels under CFG (identity preserved); + the vision tokens are only in the positive so CFG amplifies the instruction. + The tokenizer selects the right system prompt automatically (image -> TI2I, + empty negative -> DROP), so no template plumbing is needed here. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeBooguEdit", + category="model/conditioning/boogu", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.String.Input("negative_prompt", multiline=True, dynamic_prompts=True, advanced=True), + io.Vae.Input("vae"), + io.Autogrow.Input( + "images", + template=io.Autogrow.TemplateNames( + io.Image.Input("image"), + names=[f"image_{i}" for i in range(1, 17)], + min=0, + ), + tooltip="Reference image(s) to edit. Boogu focuses on one reference per sample; more are allowed.", + ), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, negative_prompt, vae=None, images: io.Autogrow.Type = None) -> io.NodeOutput: + ref_latents = [] + images_vl = [] + + images = images or {} + for name in sorted(images, key=lambda n: int(n.rsplit("_", 1)[-1])): + image = images[name] + if image is None: + continue + samples = image.movedim(-1, 1) + + # Vision tower input: the reference caps the VLM image at 384x384 + # (max_vlm_input_pil_pixels in pipeline_boogu.py). 
+ total = int(384 * 384) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + images_vl.append(s.movedim(1, -1)[:, :, :, :3]) + + # Reference latent: align to 16 px (VAE /8 * patch_size 2). + if vae is not None: + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 16.0) * 16 + height = round(samples.shape[2] * scale_by / 16.0) * 16 + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3])) + + # positive: instruction + vision tokens; negative: empty (no vision). Ref latent on both. + positive = clip.encode_from_tokens_scheduled(clip.tokenize(prompt, images=images_vl)) + negative = clip.encode_from_tokens_scheduled(clip.tokenize(negative_prompt)) + + if len(ref_latents) > 0: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True) + + return io.NodeOutput(positive, negative) + + +class BooguExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeBooguEdit, + ] + + +async def comfy_entrypoint() -> BooguExtension: + return BooguExtension() diff --git a/comfy_extras/nodes_bounding_boxes.py b/comfy_extras/nodes_bounding_boxes.py new file mode 100644 index 000000000..77cbf8649 --- /dev/null +++ b/comfy_extras/nodes_bounding_boxes.py @@ -0,0 +1,253 @@ +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageEnhance, ImageFont +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io +from comfy_extras.color_util import hex_to_rgb, 
normalize_palette, readable_color + +_PREVIEW_LONG_EDGE = 1024 +_PREVIEW_DIM = 0.25 + + +def pixels_to_fractions(box: dict, width: int, height: int) -> dict: + w = width or 1 + h = height or 1 + return { + "x": box.get("x", 0) / w, + "y": box.get("y", 0) / h, + "w": box.get("width", 0) / w, + "h": box.get("height", 0) / h, + } + + +def fractions_to_pixels(box: dict, width: int, height: int) -> dict: + x, y = box.get("x", 0.0), box.get("y", 0.0) + w, h = box.get("w", 0.0), box.get("h", 0.0) + if w < 0: + x, w = x + w, -w + if h < 0: + y, h = y + h, -h + return { + "x": round(x * width), + "y": round(y * height), + "width": round(w * width), + "height": round(h * height), + } + + +def fractions_to_bbox_frame(boxes: list, width: int, height: int) -> list: + pixels = [ + fractions_to_pixels(box, width, height) + for box in boxes + if isinstance(box, dict) + ] + return [pixels] if pixels else [] + + +def _font(size: int): + try: + return ImageFont.load_default(size) + except Exception: + return ImageFont.load_default() + + +def _wrap(draw, text: str, font, max_w: float) -> list[str]: + lines = [] + for para in text.split("\n"): + line = "" + for word in para.split(): + test = word if not line else line + " " + word + if line and draw.textlength(test, font=font) > max_w: + lines.append(line) + line = word + else: + line = test + lines.append(line) + return lines + + +def _bg_from_image(image) -> Image.Image | None: + if image is None: + return None + try: + arr = (image[0].detach().cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(arr) + except Exception: + return None + + +def render_preview(regions, width, height, bg=None): + if bg is not None: + iw, ih = bg.size + long_edge = max(iw, ih) or 1 + scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge) + rw, rh = max(1, round(iw * scale)), max(1, round(ih * scale)) + base = bg.convert("RGB").resize((rw, rh), Image.LANCZOS) + base = ImageEnhance.Brightness(base).enhance(_PREVIEW_DIM) + img = 
base.convert("RGBA") + else: + long_edge = max(width, height) or 1 + scale = min(1.0, _PREVIEW_LONG_EDGE / long_edge) + rw, rh = max(1, round(width * scale)), max(1, round(height * scale)) + grey = round(_PREVIEW_DIM * 128) + img = Image.new("RGBA", (rw, rh), (grey, grey, grey, 255)) + + overlay = Image.new("RGBA", (rw, rh), (0, 0, 0, 0)) + draw = ImageDraw.Draw(overlay) + fs = max(10, round(rh / 64)) + font = _font(fs) + tag_font = _font(max(9, fs - 2)) + line_h = fs + 2 + + for i, region in enumerate(regions): + if not isinstance(region, dict): + continue + palette = [c for c in (region.get("palette") or []) if c] + r, g, b = hex_to_rgb(palette[0]) if palette else (140, 140, 140) + x1 = max(0, min(rw, round(region.get("x", 0) * rw))) + y1 = max(0, min(rh, round(region.get("y", 0) * rh))) + x2 = max(0, min(rw, round((region.get("x", 0) + region.get("w", 0)) * rw))) + y2 = max(0, min(rh, round((region.get("y", 0) + region.get("h", 0)) * rh))) + if x2 < x1: + x1, x2 = x2, x1 + if y2 < y1: + y1, y2 = y2, y1 + + draw.rectangle([x1, y1, x2, y2], outline=(r, g, b, 255), width=2) + + swatches = palette[:5] + if swatches and (x2 - x1) > 2: + sh = max(5, fs // 2) + seg = (x2 - x1) / len(swatches) + for p, hexc in enumerate(swatches): + sx = x1 + round(p * seg) + draw.rectangle([sx, y1, x1 + round((p + 1) * seg), y1 + sh], fill=hex_to_rgb(hexc)) + + etype = "text" if region.get("type") == "text" else "obj" + tag = str(i + 1).zfill(2) + tw = draw.textlength(tag, font=tag_font) + draw.rectangle([x1, y1, x1 + tw + 6, y1 + fs + 2], fill=(r, g, b, 255)) + tag_fill = (0, 0, 0, 255) if (0.299 * r + 0.587 * g + 0.114 * b) > 140 else (255, 255, 255, 255) + draw.text((x1 + 3, y1 + 1), tag, fill=tag_fill, font=tag_font) + + body = region.get("desc", "") or "" + if etype == "text" and region.get("text"): + body = '"%s"%s' % (region["text"], " — " + body if body else "") + if body and (x2 - x1) > 8: + ty = y1 + fs + 5 + for line in _wrap(draw, body, font, x2 - x1 - 8): + if ty > y2: + 
break + draw.text((x1 + 4, ty), line, fill=readable_color((r, g, b)) + (255,), font=font) + ty += line_h + + composed = Image.alpha_composite(img, overlay).convert("RGB") + arr = np.asarray(composed, dtype=np.float32) / 255.0 + return torch.from_numpy(arr).unsqueeze(0) + + +def boxes_to_regions(boxes, width: int, height: int) -> list: + regions: list = [] + if not isinstance(boxes, list): + return regions + for box in boxes: + if not isinstance(box, dict): + continue + meta = box.get("metadata") + meta = meta if isinstance(meta, dict) else {} + regions.append({ + **pixels_to_fractions(box, width, height), + "type": meta.get("type", "obj"), + "text": meta.get("text", ""), + "desc": meta.get("desc", ""), + "palette": meta.get("palette", []), + }) + return regions + + +def _norm_bbox(region: dict) -> list[int]: + def grid(value: float) -> int: + return max(0, min(1000, round(value * 1000))) + + x, y = region.get("x", 0.0), region.get("y", 0.0) + w, h = region.get("w", 0.0), region.get("h", 0.0) + ymin, xmin, ymax, xmax = grid(y), grid(x), grid(y + h), grid(x + w) + if ymin > ymax: + ymin, ymax = ymax, ymin + if xmin > xmax: + xmin, xmax = xmax, xmin + return [ymin, xmin, ymax, xmax] + + +def build_elements(regions: list) -> list: + elements = [] + for region in regions: + if not isinstance(region, dict): + continue + etype = "text" if region.get("type") == "text" else "obj" + element = {"type": etype} + element["bbox"] = _norm_bbox(region) + if etype == "text": + element["text"] = region.get("text", "") + element["desc"] = region.get("desc", "") + palette = normalize_palette(region.get("palette", [])) + if palette: + element["color_palette"] = palette[:5] + elements.append(element) + return elements + + +class CreateBoundingBoxes(io.ComfyNode): + @classmethod + def define_schema(cls): + editor_state = io.BoundingBoxes.Input( + "editor_state", + socketless=False, + tooltip="Draw bounding boxes and set each box type, text, description, color palette. 
Start with background element first and foreground last.", + ) + return io.Schema( + node_id="CreateBoundingBoxes", + display_name="Create Bounding Boxes", + category="utilities", + description="Draw bounding boxes in a canvas. Outputs Ideogram prompt elements, pixel-space bounding boxes, and a preview image.", + inputs=[ + io.Image.Input( + "background", + optional=True, + tooltip="Optional image used as background in the canvas and preview.", + ), + io.Int.Input("width", default=1024, min=64, max=16384, step=16, + tooltip="Width of the canvas and the pixel grid for the bounding boxes."), + io.Int.Input("height", default=1024, min=64, max=16384, step=16, + tooltip="Height of the canvas and the pixel grid for the bounding boxes."), + editor_state, + ], + outputs=[ + io.Image.Output(display_name="preview"), + io.BoundingBox.Output(display_name="bboxes"), + io.Array.Output(display_name="elements"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, width, height, editor_state=None, background=None) -> io.NodeOutput: + regions = boxes_to_regions(editor_state, width, height) + preview = render_preview(regions, width, height, _bg_from_image(background)) + return io.NodeOutput( + preview, + fractions_to_bbox_frame(regions, width, height), + build_elements(regions), + ui={"dims": [width, height]}, + ) + + +class BoundingBoxesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [CreateBoundingBoxes] + + +async def comfy_entrypoint() -> BoundingBoxesExtension: + return BoundingBoxesExtension() diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py index 688254e4e..f58e51bff 100644 --- a/comfy_extras/nodes_color.py +++ b/comfy_extras/nodes_color.py @@ -1,5 +1,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io +from comfy_extras.color_util import hex_to_rgb class ColorToRGBInt(io.ComfyNode): @@ -24,9 +25,11 @@ class ColorToRGBInt(io.ComfyNode): # 
expect format #RRGGBB if len(color) != 7 or color[0] != "#": raise ValueError("Color must be in format #RRGGBB") - r = int(color[1:3], 16) - g = int(color[3:5], 16) - b = int(color[5:7], 16) + try: + int(color[1:], 16) + except ValueError: + raise ValueError("Color must be in format #RRGGBB") from None + r, g, b = hex_to_rgb(color) rgb_int = r * 256 * 256 + g * 256 + b return io.NodeOutput(rgb_int, color) diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py index b745a43af..c8091b7a4 100644 --- a/comfy_extras/nodes_cond.py +++ b/comfy_extras/nodes_cond.py @@ -8,7 +8,8 @@ class CLIPTextEncodeControlnet(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CLIPTextEncodeControlnet", - category="experimental/conditioning", + display_name="CLIP Text Encode (Controlnet)", + category="model/conditioning", inputs=[ io.Clip.Input("clip"), io.Conditioning.Input("conditioning"), @@ -35,11 +36,12 @@ class T5TokenizerOptions(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="T5TokenizerOptions", - category="experimental/conditioning", + display_name="T5 Tokenizer Options", + category="model/conditioning", inputs=[ io.Clip.Input("clip"), - io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True), - io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True), + io.Int.Input("min_padding", default=0, min=0, max=10000, step=1), + io.Int.Input("min_length", default=0, min=0, max=10000, step=1), ], outputs=[io.Clip.Output()], is_experimental=True, diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 098c26f23..15d2dc506 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -13,21 +13,22 @@ class ContextWindowsManualNode(io.ComfyNode): description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - 
io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), - io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + ], default=comfy.context_windows.ContextSchedules.STATIC_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + 
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window. For concat-style I2V models (e.g. Wan I2V, HunyuanVideo I2V, Cosmos I2V, SVD) the encoded start image lives in the c_concat conditioning channels; setting this to '0' will retain that start image content at sub-pos 0 of every window."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("latent_retain_index_list", default="", tooltip="List of latent indices to retain in the noise latent itself for each window. Use for workflows where reference content (e.g. a start image) lives directly in the noise latent rather than in separate conditioning channels (e.g. inplace-style I2V like LTXV, AnimateDiff). Independent of cond_retain_index_list."), io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."), ], outputs=[ @@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model: + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, latent_retain_index_list: list[int]=[], causal_window_fix: bool=True) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -51,6 +52,7 @@ class ContextWindowsManualNode(io.ComfyNode): freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, 
split_conds_to_windows=split_conds_to_windows, + latent_retain_index_list=latent_retain_index_list, causal_window_fix=causal_window_fix, ) # make memory usage calculation only take into account the context window latents @@ -65,33 +67,71 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): schema = super().define_schema() schema.node_id = "WanContextWindowsManual" schema.display_name = "WAN Context Windows (Manual)" - schema.description = "Manually set context windows for WAN-like models (dim=2)." + schema.display_name = "Wan Context Windows" + schema.description = "Set context windows for Wan-like models." schema.category="model/patch/wan" schema.inputs = [ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True), - io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window in real frames. 
Must be 4*n + 1."), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window in real frames."), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, - ], tooltip="The stride of the context window."), + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), - io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), - io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the 
first I2V frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), ] return schema @classmethod def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, - cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: - context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 - context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) + retain_first_frame: bool=False, split_conds_to_windows: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 4, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) + + +class LTXVContextWindowsNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "LTXVContextWindows" + schema.display_name = "LTXV Context Windows" + schema.description = "Set context windows for LTXV-like models." 
+ schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=8, default=145, tooltip="The length of the context window in real frames. Must be 8*n + 1."), + io.Int.Input("context_overlap", min=0, step=8, default=40, tooltip="The overlap of the context window in real frames."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], default=comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, tooltip="Step-dependent scheduling algorithm for context windows."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules.", advanced=True), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=True, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending.", advanced=True), + io.Boolean.Input("retain_first_frame", default=False, tooltip="Retain the first latent frame in every context window (may help retain initial reference)."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index.", advanced=True), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, fuse_method: str, freenoise: 
bool, + retain_first_frame: bool=False, split_conds_to_windows: bool=False, context_stride: int=1, closed_loop: bool=False) -> io.Model: + context_length = max(((context_length - 1) // 8) + 1, 1) # at least length 1 + context_overlap = max(context_overlap // 8, 0) # at least overlap 0 + retain_index_list = "0" if retain_first_frame else "" + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, + cond_retain_index_list=retain_index_list, latent_retain_index_list=retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): @@ -99,6 +139,7 @@ class ContextWindowsExtension(ComfyExtension): return [ ContextWindowsManualNode, WanContextWindowsManualNode, + LTXVContextWindowsNode, ] def comfy_entrypoint(): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index c9d7e06fc..56ef5f526 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -1070,7 +1070,7 @@ class AddNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="AddNoise", - category="experimental/custom_sampling/noise", + category="model/sampling/noise", is_experimental=True, inputs=[ io.Model.Input("model"), @@ -1120,7 +1120,7 @@ class ManualSigmas(io.ComfyNode): return io.Schema( node_id="ManualSigmas", search_aliases=["custom noise schedule", "define sigmas"], - category="experimental/custom_sampling", + category="model/sampling/sigmas", is_experimental=True, inputs=[ io.String.Input("sigmas", default="1, 0.5", multiline=False) diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 0253b4b4f..73fe75b7f 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1583,7 +1583,7 @@ class LoadTrainingDataset(io.ComfyNode): shard_path = os.path.join(dataset_dir, shard_file) with open(shard_path, "rb") as f: - shard_data = 
torch.load(f) + shard_data = torch.load(f, weights_only=True) all_latents.extend(shard_data["latents"]) all_conditioning.extend(shard_data["conditioning"]) diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 4d5bca17e..44708e5ec 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -77,7 +77,7 @@ class FrameInterpolate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FrameInterpolate", - display_name="Frame Interpolate", + display_name="Run Frame Interpolation Model", category="video", search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index ea7420a73..c7161973a 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -1,85 +1,68 @@ import os import sys import re +import ctypes import logging -import ctypes.util -import importlib.util from typing import TypedDict import numpy as np import torch import nodes +import comfy_angle from comfy_api.latest import ComfyExtension, io, ui from typing_extensions import override -from utils.install_util import get_missing_requirements_message logger = logging.getLogger(__name__) -def _check_opengl_availability(): - """Early check for OpenGL availability. 
Raises RuntimeError if unlikely to work.""" - logger.debug("_check_opengl_availability: starting") - missing = [] +def _preload_angle(): + egl_path = comfy_angle.get_egl_path() + gles_path = comfy_angle.get_glesv2_path() - # Check Python packages (using find_spec to avoid importing) - logger.debug("_check_opengl_availability: checking for glfw package") - if importlib.util.find_spec("glfw") is None: - missing.append("glfw") + if sys.platform == "win32": + angle_dir = comfy_angle.get_lib_dir() + os.add_dll_directory(angle_dir) + os.environ["PATH"] = angle_dir + os.pathsep + os.environ.get("PATH", "") - logger.debug("_check_opengl_availability: checking for OpenGL package") - if importlib.util.find_spec("OpenGL") is None: - missing.append("PyOpenGL") - - if missing: - raise RuntimeError( - f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n" - ) - - # On Linux without display, check if headless backends are available - logger.debug(f"_check_opengl_availability: platform={sys.platform}") - if sys.platform.startswith("linux"): - has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY") - logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}") - if not has_display: - # Check for EGL or OSMesa libraries - logger.debug("_check_opengl_availability: checking for EGL library") - has_egl = ctypes.util.find_library("EGL") - logger.debug("_check_opengl_availability: checking for OSMesa library") - has_osmesa = ctypes.util.find_library("OSMesa") - - # Error disabled for CI as it fails this check - # if not has_egl and not has_osmesa: - # raise RuntimeError( - # "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n" - # "See error below for installation instructions." 
- # ) - logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}") - - logger.debug("_check_opengl_availability: completed") + mode = 0 if sys.platform == "win32" else ctypes.RTLD_GLOBAL + ctypes.CDLL(str(egl_path), mode=mode) + ctypes.CDLL(str(gles_path), mode=mode) -# Run early check at import time -logger.debug("nodes_glsl: running _check_opengl_availability at import time") -_check_opengl_availability() - -# OpenGL modules - initialized lazily when context is created -gl = None -glfw = None -EGL = None +# Pre-load ANGLE *before* any PyOpenGL import so that the EGL platform +# plugin picks up ANGLE's libEGL / libGLESv2 instead of system libs. +_preload_angle() +os.environ.setdefault("PYOPENGL_PLATFORM", "egl") -def _import_opengl(): - """Import OpenGL module. Called after context is created.""" - global gl - if gl is None: - logger.debug("_import_opengl: importing OpenGL.GL") - import OpenGL.GL as _gl - gl = _gl - logger.debug("_import_opengl: import completed") - return gl +import OpenGL +OpenGL.USE_ACCELERATE = False +def _patch_find_library(): + """PyOpenGL's EGL platform looks for 'EGL' and 'GLESv2' by short name + via ctypes.util.find_library, but ANGLE ships as 'libEGL' and + 'libGLESv2'. 
Patch find_library to return the full ANGLE paths so + PyOpenGL loads the same libraries we pre-loaded.""" + if sys.platform == "linux": + return + import ctypes.util + _orig = ctypes.util.find_library + def _patched(name): + if name == 'EGL': + return comfy_angle.get_egl_path() + if name == 'GLESv2': + return comfy_angle.get_glesv2_path() + return _orig(name) + ctypes.util.find_library = _patched + + +_patch_find_library() + +from OpenGL import EGL +from OpenGL import GLES3 as gl + class SizeModeInput(TypedDict): size_mode: str width: int @@ -102,7 +85,7 @@ MAX_OUTPUTS = 4 # fragColor0-3 (MRT) # (-1,-1)---(3,-1) # # v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1) -VERTEX_SHADER = """#version 330 core +VERTEX_SHADER = """#version 300 es out vec2 v_texCoord; void main() { vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3)); @@ -126,14 +109,99 @@ void main() { """ -def _convert_es_to_desktop(source: str) -> str: - """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.""" - # Remove any existing #version directive - source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE) - # Remove precision qualifiers (not needed in desktop GLSL) - source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source) - # Prepend desktop GLSL version - return "#version 330 core\n" + source + +def _egl_attribs(*values): + """Build an EGL_NONE-terminated EGLint attribute array.""" + vals = list(values) + [EGL.EGL_NONE] + return (ctypes.c_int32 * len(vals))(*vals) + + +# EGL platform extension constants +EGL_PLATFORM_ANGLE_ANGLE = 0x3202 +EGL_PLATFORM_ANGLE_TYPE_ANGLE = 0x3203 +EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE = 0x3450 +EGL_MESA_PLATFORM_SURFACELESS = 0x31DD + + +_eglGetPlatformDisplayEXT = None + +def _get_egl_platform_display_ext(platform, native_display, attribs): + """Call eglGetPlatformDisplayEXT via ctypes (extension, not in PyOpenGL).""" + global _eglGetPlatformDisplayEXT + if 
_eglGetPlatformDisplayEXT is None: + from OpenGL import platform as _plat + egl_lib = _plat.PLATFORM.EGL + _get_proc = egl_lib.eglGetProcAddress + _get_proc.restype = ctypes.c_void_p + _get_proc.argtypes = [ctypes.c_char_p] + ptr = _get_proc(b"eglGetPlatformDisplayEXT") + if not ptr: + return None + func_type = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_uint32, ctypes.c_void_p, ctypes.c_void_p) + _eglGetPlatformDisplayEXT = func_type(ptr) + + raw = _eglGetPlatformDisplayEXT(platform, native_display, attribs) + if not raw: + return None + return ctypes.cast(raw, EGL.EGLDisplay) + + +def _get_egl_display(): + """Get an EGL display, trying the default first then ANGLE's Vulkan + platform for headless environments without a display server.""" + failures = [] + + # Try the default display first (works when X11/Wayland is available) + display = EGL.eglGetDisplay(EGL.EGL_DEFAULT_DISPLAY) + if display: + major, minor = ctypes.c_int32(0), ctypes.c_int32(0) + try: + if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)): + return display, major.value, minor.value + except Exception as e: + failures.append(f"default: {e}") + + logger.info("Default EGL display unavailable, trying headless fallbacks") + + # Headless fallback strategies, tried in order: + headless_strategies = [ + ("surfaceless", EGL_MESA_PLATFORM_SURFACELESS, None, None), + ("ANGLE Vulkan", EGL_PLATFORM_ANGLE_ANGLE, None, + _egl_attribs(EGL_PLATFORM_ANGLE_TYPE_ANGLE, EGL_PLATFORM_ANGLE_TYPE_VULKAN_ANGLE)), + ] + + for name, platform, native_display, attribs in headless_strategies: + display = _get_egl_platform_display_ext(platform, native_display, attribs) + if not display: + failures.append(f"{name}: eglGetPlatformDisplayEXT returned no display") + continue + major, minor = ctypes.c_int32(0), ctypes.c_int32(0) + try: + if EGL.eglInitialize(display, ctypes.byref(major), ctypes.byref(minor)): + logger.info(f"Using EGL {name} platform (headless)") + return display, major.value, minor.value + 
failures.append(f"{name}: eglInitialize returned false") + except Exception as e: + failures.append(f"{name}: {e}") + continue + + details = "\n".join(f" - {f}" for f in failures) + raise RuntimeError( + "Failed to initialize EGL display.\n" + "No display server and no headless EGL platform available.\n" + f"Tried:\n{details}\n" + "Ensure GPU drivers are installed or set DISPLAY for a virtual framebuffer." + ) + + +def _gl_str(name): + """Get an OpenGL string parameter.""" + v = gl.glGetString(name) + if not v: + return "Unknown" + if isinstance(v, bytes): + return v.decode(errors="replace") + return ctypes.string_at(v).decode(errors="replace") def _detect_output_count(source: str) -> int: @@ -159,163 +227,8 @@ def _detect_pass_count(source: str) -> int: return 1 -def _init_glfw(): - """Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure.""" - logger.debug("_init_glfw: starting") - # On macOS, glfw.init() must be called from main thread or it hangs forever - if sys.platform == "darwin": - logger.debug("_init_glfw: skipping on macOS") - raise RuntimeError("GLFW backend not supported on macOS") - - logger.debug("_init_glfw: importing glfw module") - import glfw as _glfw - - logger.debug("_init_glfw: calling glfw.init()") - if not _glfw.init(): - raise RuntimeError("glfw.init() failed") - - try: - logger.debug("_init_glfw: setting window hints") - _glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE) - _glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3) - _glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3) - _glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE) - - logger.debug("_init_glfw: calling create_window()") - window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None) - if not window: - raise RuntimeError("glfw.create_window() failed") - - logger.debug("_init_glfw: calling make_context_current()") - _glfw.make_context_current(window) - logger.debug("_init_glfw: completed successfully") - return window, _glfw - except 
Exception: - logger.debug("_init_glfw: failed, terminating glfw") - _glfw.terminate() - raise - - -def _init_egl(): - """Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). Raises RuntimeError on failure.""" - logger.debug("_init_egl: starting") - from OpenGL import EGL as _EGL - from OpenGL.EGL import ( - eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext, - eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI, - eglTerminate, eglDestroyContext, eglDestroySurface, - EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE, - EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, - EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE, - EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API, - ) - logger.debug("_init_egl: imports completed") - - display = None - context = None - surface = None - - try: - logger.debug("_init_egl: calling eglGetDisplay()") - display = eglGetDisplay(EGL_DEFAULT_DISPLAY) - if display == _EGL.EGL_NO_DISPLAY: - raise RuntimeError("eglGetDisplay() failed") - - logger.debug("_init_egl: calling eglInitialize()") - major, minor = _EGL.EGLint(), _EGL.EGLint() - if not eglInitialize(display, major, minor): - display = None # Not initialized, don't terminate - raise RuntimeError("eglInitialize() failed") - logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}") - - config_attribs = [ - EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, - EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, - EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8, - EGL_DEPTH_SIZE, 0, EGL_NONE - ] - configs = (_EGL.EGLConfig * 1)() - num_configs = _EGL.EGLint() - if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0: - raise RuntimeError("eglChooseConfig() failed") - config = configs[0] - logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}") - - if not eglBindAPI(EGL_OPENGL_API): - raise RuntimeError("eglBindAPI() failed") - - 
logger.debug("_init_egl: calling eglCreateContext()") - context_attribs = [ - _EGL.EGL_CONTEXT_MAJOR_VERSION, 3, - _EGL.EGL_CONTEXT_MINOR_VERSION, 3, - _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT, - EGL_NONE - ] - context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs) - if context == EGL_NO_CONTEXT: - raise RuntimeError("eglCreateContext() failed") - - logger.debug("_init_egl: calling eglCreatePbufferSurface()") - pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE] - surface = eglCreatePbufferSurface(display, config, pbuffer_attribs) - if surface == _EGL.EGL_NO_SURFACE: - raise RuntimeError("eglCreatePbufferSurface() failed") - - logger.debug("_init_egl: calling eglMakeCurrent()") - if not eglMakeCurrent(display, surface, surface, context): - raise RuntimeError("eglMakeCurrent() failed") - - logger.debug("_init_egl: completed successfully") - return display, context, surface, _EGL - - except Exception: - logger.debug("_init_egl: failed, cleaning up") - # Clean up any resources on failure - if surface is not None: - eglDestroySurface(display, surface) - if context is not None: - eglDestroyContext(display, context) - if display is not None: - eglTerminate(display) - raise - - -def _init_osmesa(): - """Initialize OSMesa for software rendering. Returns (context, buffer). 
Raises RuntimeError on failure.""" - import ctypes - - logger.debug("_init_osmesa: starting") - os.environ["PYOPENGL_PLATFORM"] = "osmesa" - - logger.debug("_init_osmesa: importing OpenGL.osmesa") - from OpenGL import GL as _gl - from OpenGL.osmesa import ( - OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext, - OSMESA_RGBA, - ) - logger.debug("_init_osmesa: imports completed") - - ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None) - if not ctx: - raise RuntimeError("OSMesaCreateContextExt() failed") - - width, height = 64, 64 - buffer = (ctypes.c_ubyte * (width * height * 4))() - - logger.debug("_init_osmesa: calling OSMesaMakeCurrent()") - if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height): - OSMesaDestroyContext(ctx) - raise RuntimeError("OSMesaMakeCurrent() failed") - - logger.debug("_init_osmesa: completed successfully") - return ctx, buffer - - class GLContext: - """Manages OpenGL context and resources for shader execution. - - Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software). 
- """ + """Manages an OpenGL ES 3.0 context via EGL/ANGLE (singleton).""" _instance = None _initialized = False @@ -327,131 +240,105 @@ class GLContext: def __init__(self): if GLContext._initialized: - logger.debug("GLContext.__init__: already initialized, skipping") return - logger.debug("GLContext.__init__: starting initialization") - - global glfw, EGL - import time start = time.perf_counter() - self._backend = None - self._window = None - self._egl_display = None - self._egl_context = None - self._egl_surface = None - self._osmesa_ctx = None - self._osmesa_buffer = None + self._display = None + self._surface = None + self._context = None self._vao = None - # Try backends in order: GLFW → EGL → OSMesa - errors = [] - - logger.debug("GLContext.__init__: trying GLFW backend") try: - self._window, glfw = _init_glfw() - self._backend = "glfw" - logger.debug("GLContext.__init__: GLFW backend succeeded") - except Exception as e: - logger.debug(f"GLContext.__init__: GLFW backend failed: {e}") - errors.append(("GLFW", e)) + self._display, self._egl_major, self._egl_minor = _get_egl_display() - if self._backend is None: - logger.debug("GLContext.__init__: trying EGL backend") - try: - self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl() - self._backend = "egl" - logger.debug("GLContext.__init__: EGL backend succeeded") - except Exception as e: - logger.debug(f"GLContext.__init__: EGL backend failed: {e}") - errors.append(("EGL", e)) + if not EGL.eglBindAPI(EGL.EGL_OPENGL_ES_API): + raise RuntimeError("eglBindAPI(EGL_OPENGL_ES_API) failed") - if self._backend is None: - logger.debug("GLContext.__init__: trying OSMesa backend") - try: - self._osmesa_ctx, self._osmesa_buffer = _init_osmesa() - self._backend = "osmesa" - logger.debug("GLContext.__init__: OSMesa backend succeeded") - except Exception as e: - logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}") - errors.append(("OSMesa", e)) + config = EGL.EGLConfig() + n_configs = 
ctypes.c_int32(0) + if not EGL.eglChooseConfig( + self._display, + _egl_attribs( + EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_ES3_BIT, + EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT, + EGL.EGL_RED_SIZE, 8, EGL.EGL_GREEN_SIZE, 8, + EGL.EGL_BLUE_SIZE, 8, EGL.EGL_ALPHA_SIZE, 8, + ), + ctypes.byref(config), 1, ctypes.byref(n_configs), + ) or n_configs.value == 0: + raise RuntimeError("eglChooseConfig() failed") - if self._backend is None: - if sys.platform == "win32": - platform_help = ( - "Windows: Ensure GPU drivers are installed and display is available.\n" - " CPU-only/headless mode is not supported on Windows." - ) - elif sys.platform == "darwin": - platform_help = ( - "macOS: GLFW is not supported.\n" - " Install OSMesa via Homebrew: brew install mesa\n" - " Then: pip install PyOpenGL PyOpenGL-accelerate" - ) - else: - platform_help = ( - "Linux: Install one of these backends:\n" - " Desktop: sudo apt install libgl1-mesa-glx libglfw3\n" - " Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n" - " Headless (CPU): sudo apt install libosmesa6" - ) - - error_details = "\n".join(f" {name}: {err}" for name, err in errors) - raise RuntimeError( - f"Failed to create OpenGL context.\n\n" - f"Backend errors:\n{error_details}\n\n" - f"{platform_help}" + self._surface = EGL.eglCreatePbufferSurface( + self._display, config, + _egl_attribs(EGL.EGL_WIDTH, 64, EGL.EGL_HEIGHT, 64), ) + if not self._surface: + raise RuntimeError("eglCreatePbufferSurface() failed") - # Now import OpenGL.GL (after context is current) - logger.debug("GLContext.__init__: importing OpenGL.GL") - _import_opengl() + self._context = EGL.eglCreateContext( + self._display, config, EGL.EGL_NO_CONTEXT, + _egl_attribs(EGL.EGL_CONTEXT_CLIENT_VERSION, 3), + ) + if not self._context: + raise RuntimeError("eglCreateContext() failed") - # Create VAO (required for core profile, but OSMesa may use compat profile) - logger.debug("GLContext.__init__: creating VAO") - try: - vao = gl.glGenVertexArrays(1) - 
gl.glBindVertexArray(vao) - self._vao = vao # Only store after successful bind - logger.debug("GLContext.__init__: VAO created successfully") - except Exception as e: - logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}") - # OSMesa with older Mesa may not support VAOs - # Clean up if we created but couldn't bind - if vao: - try: - gl.glDeleteVertexArrays(1, [vao]) - except Exception: - pass + if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context): + raise RuntimeError("eglMakeCurrent() failed") + + self._vao = gl.glGenVertexArrays(1) + gl.glBindVertexArray(self._vao) + + except Exception: + self._cleanup() + raise elapsed = (time.perf_counter() - start) * 1000 - # Log device info - renderer = gl.glGetString(gl.GL_RENDERER) - vendor = gl.glGetString(gl.GL_VENDOR) - version = gl.glGetString(gl.GL_VERSION) - renderer = renderer.decode() if renderer else "Unknown" - vendor = vendor.decode() if vendor else "Unknown" - version = version.decode() if version else "Unknown" + renderer = _gl_str(gl.GL_RENDERER) + vendor = _gl_str(gl.GL_VENDOR) + version = _gl_str(gl.GL_VERSION) GLContext._initialized = True - logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}") + logger.info(f"GLSL context initialized in {elapsed:.1f}ms - EGL {self._egl_major}.{self._egl_minor}, {renderer} ({vendor}), GL {version}") def make_current(self): - if self._backend == "glfw": - glfw.make_context_current(self._window) - elif self._backend == "egl": - from OpenGL.EGL import eglMakeCurrent - eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context) - elif self._backend == "osmesa": - from OpenGL.osmesa import OSMesaMakeCurrent - OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64) - + if not EGL.eglMakeCurrent(self._display, self._surface, self._surface, self._context): + err = EGL.eglGetError() + raise 
RuntimeError(f"eglMakeCurrent() failed (EGL error: 0x{err:04X})") if self._vao is not None: gl.glBindVertexArray(self._vao) + def _cleanup(self): + if not self._display: + return + try: + if self._vao is not None: + gl.glDeleteVertexArrays(1, [self._vao]) + self._vao = None + except Exception: + pass + try: + EGL.eglMakeCurrent(self._display, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT) + except Exception: + pass + try: + if self._context: + EGL.eglDestroyContext(self._display, self._context) + except Exception: + pass + try: + if self._surface: + EGL.eglDestroySurface(self._display, self._surface) + except Exception: + pass + try: + EGL.eglTerminate(self._display) + except Exception: + pass + self._display = None + def _compile_shader(source: str, shader_type: int) -> int: """Compile a shader and return its ID.""" @@ -459,8 +346,10 @@ def _compile_shader(source: str, shader_type: int) -> int: gl.glShaderSource(shader, source) gl.glCompileShader(shader) - if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE: - error = gl.glGetShaderInfoLog(shader).decode() + if not gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS): + error = gl.glGetShaderInfoLog(shader) + if isinstance(error, bytes): + error = error.decode(errors="replace") gl.glDeleteShader(shader) raise RuntimeError(f"Shader compilation failed:\n{error}") @@ -484,8 +373,10 @@ def _create_program(vertex_source: str, fragment_source: str) -> int: gl.glDeleteShader(vertex_shader) gl.glDeleteShader(fragment_shader) - if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE: - error = gl.glGetProgramInfoLog(program).decode() + if not gl.glGetProgramiv(program, gl.GL_LINK_STATUS): + error = gl.glGetProgramInfoLog(program) + if isinstance(error, bytes): + error = error.decode(errors="replace") gl.glDeleteProgram(program) raise RuntimeError(f"Program linking failed:\n{error}") @@ -530,9 +421,6 @@ def _render_shader_batch( ctx = GLContext() ctx.make_current() - # Convert from GLSL ES to 
desktop GLSL 330 - fragment_source = _convert_es_to_desktop(fragment_code) - # Detect how many outputs the shader actually uses num_outputs = _detect_output_count(fragment_code) @@ -558,9 +446,9 @@ def _render_shader_batch( try: # Compile shaders (once for all batches) try: - program = _create_program(VERTEX_SHADER, fragment_source) + program = _create_program(VERTEX_SHADER, fragment_code) except RuntimeError: - logger.error(f"Fragment shader:\n{fragment_source}") + logger.error(f"Fragment shader:\n{fragment_code}") raise gl.glUseProgram(program) @@ -723,13 +611,13 @@ def _render_shader_batch( gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) # Read back outputs for this batch - # (glGetTexImage is synchronous, implicitly waits for rendering) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) batch_outputs = [] - for tex in output_textures: - gl.glBindTexture(gl.GL_TEXTURE_2D, tex) - data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT) - img = np.frombuffer(data, dtype=np.float32).reshape(height, width, 4) - batch_outputs.append(img[::-1, :, :].copy()) + for i in range(num_outputs): + gl.glReadBuffer(gl.GL_COLOR_ATTACHMENT0 + i) + buf = np.empty((height, width, 4), dtype=np.float32) + gl.glReadPixels(0, 0, width, height, gl.GL_RGBA, gl.GL_FLOAT, buf) + batch_outputs.append(buf[::-1, :, :].copy()) # Pad with black images for unused outputs black_img = np.zeros((height, width, 4), dtype=np.float32) @@ -750,18 +638,18 @@ def _render_shader_batch( gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) gl.glUseProgram(0) - for tex in input_textures: - gl.glDeleteTextures(int(tex)) - for tex in curve_textures: - gl.glDeleteTextures(int(tex)) - for tex in output_textures: - gl.glDeleteTextures(int(tex)) - for tex in ping_pong_textures: - gl.glDeleteTextures(int(tex)) + if input_textures: + gl.glDeleteTextures(len(input_textures), input_textures) + if curve_textures: + gl.glDeleteTextures(len(curve_textures), curve_textures) + if output_textures: + 
gl.glDeleteTextures(len(output_textures), output_textures) + if ping_pong_textures: + gl.glDeleteTextures(len(ping_pong_textures), ping_pong_textures) if fbo is not None: gl.glDeleteFramebuffers(1, [fbo]) - for pp_fbo in ping_pong_fbos: - gl.glDeleteFramebuffers(1, [pp_fbo]) + if ping_pong_fbos: + gl.glDeleteFramebuffers(len(ping_pong_fbos), ping_pong_fbos) if program is not None: gl.glDeleteProgram(program) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 469a7be55..fe1937ba5 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -214,11 +214,13 @@ class SaveAnimatedWEBP(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: return IO.NodeOutput( + images, ui=UI.ImageSaveHelper.get_save_animated_webp_ui( images=images, filename_prefix=filename_prefix, @@ -230,8 +232,6 @@ class SaveAnimatedWEBP(IO.ComfyNode): ) ) - save_images = execute # TODO: remove - class SaveAnimatedPNG(IO.ComfyNode): @@ -249,11 +249,13 @@ class SaveAnimatedPNG(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: return IO.NodeOutput( + images, ui=UI.ImageSaveHelper.get_save_animated_png_ui( images=images, filename_prefix=filename_prefix, @@ -263,8 +265,6 @@ class SaveAnimatedPNG(IO.ComfyNode): ) ) - save_images = execute # TODO: remove - class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @@ -513,6 +513,7 @@ class SaveSVGNode(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.SVG.Output("svg")], ) @classmethod @@ -562,9 +563,7 @@ class 
SaveSVGNode(IO.ComfyNode): results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return IO.NodeOutput(ui={"images": results}) - - save_svg = execute # TODO: remove + return IO.NodeOutput(svg, ui={"images": results}) class GetImageSize(IO.ComfyNode): @@ -1157,40 +1156,27 @@ class SaveImageAdvanced(IO.ComfyNode): IO.String.Input( "filename_prefix", default="ComfyUI", - tooltip=( - "The prefix for the file to save. May include formatting tokens " - "such as %date:yyyy-MM-dd% or %Empty Latent Image.width%." - ), + tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."), ), IO.DynamicCombo.Input( "format", options=[ IO.DynamicCombo.Option("png", [ - IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], - default="8-bit", advanced=True), - IO.Combo.Input("input_color_space", options=["sRGB"], - default="sRGB", advanced=True), + IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], default="8-bit", advanced=True), + IO.Combo.Input("input_color_space", options=["sRGB"], default="sRGB", advanced=True), ]), IO.DynamicCombo.Option("exr", [ - IO.Combo.Input("bit_depth", options=["32-bit float"], - default="32-bit float", advanced=True), + IO.Combo.Input("bit_depth", options=["32-bit float"], default="32-bit float", advanced=True), IO.Combo.Input( "input_color_space", options=["sRGB", "HDR", "linear"], default="sRGB", advanced=True, tooltip=( - "Colorspace of the input tensor. The EXR is " - "always written as scene-linear in the matching " - "gamut.\n" - " 'sRGB' — input is sRGB-encoded Rec.709; " - "the inverse sRGB EOTF is applied.\n" - " 'HDR' — input is HLG-encoded Rec.2020 " - "(BT.2100); the inverse HLG OETF is applied " - "to get scene-linear light.\n" - " 'linear' — input is already scene-linear " - "(Rec.709 primaries); written through unchanged. " - "Use this for renderer/compositor output." + "Colorspace of the input tensor. 
The EXR is always written as scene-linear in the matching gamut.\n" + "sRGB — input is sRGB-encoded Rec.709; the inverse sRGB EOTF is applied.\n" + "HDR — input is HLG-encoded Rec.2020 (BT.2100); the inverse HLG OETF is applied to get scene-linear light.\n" + "linear — input is already scene-linear (Rec.709 primaries); written through unchanged. Use this for renderer/compositor output." ), ), ]), @@ -1200,6 +1186,7 @@ class SaveImageAdvanced(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod @@ -1237,7 +1224,7 @@ class SaveImageAdvanced(IO.ComfyNode): results.append({"filename": file, "subfolder": subfolder, "type": "output"}) counter += 1 - return IO.NodeOutput(ui={"images": results}) + return IO.NodeOutput(images, ui={"images": results}) class ImagesExtension(ComfyExtension): diff --git a/comfy_extras/nodes_json_prompt.py b/comfy_extras/nodes_json_prompt.py new file mode 100644 index 000000000..206f5aa71 --- /dev/null +++ b/comfy_extras/nodes_json_prompt.py @@ -0,0 +1,77 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io +from comfy_extras.color_util import normalize_palette + + +class BuildJsonPromptIdeogram(io.ComfyNode): + @classmethod + def define_schema(cls): + color_palette = io.Colors.Input( + "color_palette", + socketless=False, + tooltip="Hex color codes that steer the image's dominant colors. Up to 16 entries.", + ) + return io.Schema( + node_id="BuildJsonPromptIdeogram", + display_name="Build JSON Prompt (Ideogram)", + category="text", + description="Build a JSON prompt for the Ideogram 4 model.", + inputs=[ + io.Array.Input("element", tooltip="Prompt elements from the node Create Bounding Boxes."), + io.String.Input("high_level_description", multiline=True, default="", + tooltip="Optional description of the image in one or two sentences. 
Strongly recommended."), + io.String.Input("background", multiline=True, default="", + tooltip="Mandatory description of the image background or environment."), + io.DynamicCombo.Input("style", options=[ + io.DynamicCombo.Option("none", []), + io.DynamicCombo.Option("photo", [io.String.Input("photo", default="", tooltip="Camera or lens details for photographic outputs (e.g. 35mm, f/1.4, bokeh).")]), + io.DynamicCombo.Option("art_style", [io.String.Input("art_style", default="", tooltip="Art style description (e.g. flat vector illustration, bold outlines).")]), + ]), + io.String.Input("aesthetics", default="", tooltip="Mandatory aesthetic keywords (e.g. moody, cinematic, desaturated)."), + io.String.Input("lighting", default="", tooltip="Mandatory lighting description (e.g. golden hour, rim light, dramatic shadows)."), + io.String.Input("medium", default="", tooltip="Mandatory medium type (e.g. photograph, illustration, 3d_render, painting, graphic_design). When style = photo, set to photograph."), + color_palette, + ], + outputs=[io.Dict.Output(display_name="prompt")], + is_experimental=True, + ) + + @classmethod + def execute(cls, element, style, high_level_description="", background="", + aesthetics="", lighting="", medium="", color_palette=None) -> io.NodeOutput: + elements = element if isinstance(element, list) else [] + kind = style.get("style", "none") if isinstance(style, dict) else "none" + photo = style.get("photo", "") if isinstance(style, dict) else "" + art_style = style.get("art_style", "") if isinstance(style, dict) else "" + palette = normalize_palette(color_palette or []) + + caption: dict = {} + if high_level_description.strip(): + caption["high_level_description"] = high_level_description + if kind != "none": + style_desc: dict = {"aesthetics": aesthetics, "lighting": lighting} + if kind == "photo": + style_desc["photo"] = photo + style_desc["medium"] = medium + else: + style_desc["medium"] = medium + style_desc["art_style"] = art_style + if 
palette: + style_desc["color_palette"] = palette + caption["style_description"] = style_desc + caption["compositional_deconstruction"] = { + "background": background, + "elements": elements, + } + return io.NodeOutput(caption) + + +class JsonPromptExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [BuildJsonPromptIdeogram] + + +async def comfy_entrypoint() -> JsonPromptExtension: + return JsonPromptExtension() diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 455897859..6e3e88471 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -317,11 +317,74 @@ class PreviewPointCloud(IO.ComfyNode): ) +MESH_EXTENSIONS = {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + + +class Load3DAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + input_dir = os.path.join(folder_paths.get_input_directory(), "3d") + os.makedirs(input_dir, exist_ok=True) + + input_path = Path(input_dir) + base_path = Path(folder_paths.get_input_directory()) + + files = [ + normalize_path(str(file_path.relative_to(base_path))) + for file_path in input_path.rglob("*") + if file_path.suffix.lower() in MESH_EXTENSIONS + ] + return IO.Schema( + node_id="Load3DAdvanced", + display_name="Load 3D (Advanced)", + category="3d", + search_aliases=[ + "load mesh", + "load gltf", + "load glb", + "load obj", + "load fbx", + "load stl", + ], + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("viewport_state"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + 
@classmethod + def validate_inputs(cls, model_file, **kwargs) -> bool | str: + if not model_file or model_file == "none": + return True + if not folder_paths.exists_annotated_filepath(model_file): + return f"Invalid 3D model file: {model_file}" + return True + + @classmethod + def execute(cls, model_file, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + file_3d = None + if model_file and model_file != "none": + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + model_3d_info = viewport_state.get('model_3d_info', []) + return IO.NodeOutput(file_3d, model_3d_info, viewport_state['camera_info'], width, height) + + class Load3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Load3D, + Load3DAdvanced, Preview3D, Preview3DAdvanced, PreviewGaussianSplat, diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95f6ab848..13c1685f7 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -89,7 +89,8 @@ class SwitchNode(io.ComfyNode): template = io.MatchType.Template("switch") return io.Schema( node_id="ComfySwitchNode", - display_name="Switch", + search_aliases=["if", "then", "switch", "conditional", "branch"], + display_name="If/Else Switch", category="utilities/logic", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py index 2fa684b3a..e563d950b 100644 --- a/comfy_extras/nodes_model_merging_model_specific.py +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -337,6 +337,36 @@ class ModelMergeQwenImage(comfy_extras.nodes_model_merging.ModelMergeBlocks): return {"required": arg_dict} +class ModelMergeKrea2(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "model/merging/model specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", 
{"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["first."] = argument + arg_dict["tmlp."] = argument + arg_dict["txtmlp."] = argument + arg_dict["tproj."] = argument + + for i in range(2): + arg_dict["txtfusion.layerwise_blocks.{}.".format(i)] = argument + + arg_dict["txtfusion.projector."] = argument + + for i in range(2): + arg_dict["txtfusion.refiner_blocks.{}.".format(i)] = argument + + for i in range(28): + arg_dict["blocks.{}.".format(i)] = argument + + arg_dict["last."] = argument + + return {"required": arg_dict} + NODE_CLASS_MAPPINGS = { "ModelMergeSD1": ModelMergeSD1, "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks @@ -353,4 +383,5 @@ NODE_CLASS_MAPPINGS = { "ModelMergeCosmosPredict2_2B": ModelMergeCosmosPredict2_2B, "ModelMergeCosmosPredict2_14B": ModelMergeCosmosPredict2_14B, "ModelMergeQwenImage": ModelMergeQwenImage, + "ModelMergeKrea2": ModelMergeKrea2, } diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py index 8a2248572..72fad1673 100644 --- a/comfy_extras/nodes_photomaker.py +++ b/comfy_extras/nodes_photomaker.py @@ -123,7 +123,8 @@ class PhotoMakerLoader(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PhotoMakerLoader", - category="experimental/photomaker", + display_name="Load PhotoMaker Model", + category="model/loaders", inputs=[ io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")), ], @@ -149,7 +150,8 @@ class PhotoMakerEncode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PhotoMakerEncode", - category="experimental/photomaker", + display_name="PhotoMaker Encode", + category="model/conditioning/photomaker", inputs=[ io.Photomaker.Input("photomaker"), io.Image.Input("image"), diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index c44b09098..7f90daf14 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -10,12 +10,11 @@ class 
String(io.ComfyNode): return io.Schema( node_id="PrimitiveString", search_aliases=["text", "string", "text box", "prompt"], - display_name="Text String", + display_name="Text String (DEPRECATED)", category="utilities/primitive", - inputs=[ - io.String.Input("value"), - ], + inputs=[io.String.Input("value")], outputs=[io.String.Output()], + is_deprecated=True ) @classmethod @@ -29,12 +28,10 @@ class StringMultiline(io.ComfyNode): return io.Schema( node_id="PrimitiveStringMultiline", search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], - display_name="Text String (Multiline)", + display_name="Input Text", category="utilities/primitive", essentials_category="Basics", - inputs=[ - io.String.Input("value", multiline=True), - ], + inputs=[io.String.Input("value", multiline=True)], outputs=[io.String.Output()], ) diff --git a/comfy_extras/nodes_scail.py b/comfy_extras/nodes_scail.py index 007733efc..55c9897e3 100644 --- a/comfy_extras/nodes_scail.py +++ b/comfy_extras/nodes_scail.py @@ -34,14 +34,20 @@ def _unpack(track_data): return unpack_masks(packed) -def _first_frame_cx_area(masks_bool): - first = masks_bool[0].float() - H, W = first.shape[-2], first.shape[-1] - n_pixels = H * W - grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) - area = first.sum(dim=(-1, -2)).clamp_(min=1) - cx = (first * grid_x).sum(dim=(-1, -2)) / area - return (cx / W).tolist(), (area / n_pixels).tolist() +def _first_appearance_cx_area(masks_bool): + """Per object: first frame it appears in, plus centroid-x and area in that frame.""" + m = masks_bool.float() + T, H, W = m.shape[0], m.shape[-2], m.shape[-1] + grid_x = torch.arange(W, device=m.device, dtype=m.dtype).view(1, 1, 1, W) + area_t = m.sum(dim=(-1, -2)) + cx_t = (m * grid_x).sum(dim=(-1, -2)) / area_t.clamp(min=1) + present = area_t > 0 + frame_idx = torch.arange(T, device=m.device).unsqueeze(1) + first_t = torch.where(present, frame_idx, T).amin(dim=0) + sel = 
first_t.clamp(max=T - 1).unsqueeze(0) + cx = cx_t.gather(0, sel).squeeze(0) + area = area_t.gather(0, sel).squeeze(0) + return first_t.tolist(), (cx / W).tolist(), (area / (H * W)).tolist() def _subset_track_data(track_data, obj_indices): @@ -81,12 +87,26 @@ def _render_colored_masks(track_data, background="black"): masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" ).view(T, N_obj, H, W) > 0.5 any_mask = masks_full.any(dim=1) - obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) - color_overlay = colors[obj_idx_map] + color_overlay = colors[masks_full.to(torch.uint8).argmax(dim=1)] bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) +def _render_mask_as_identity(mask, background="black"): + """Plain comfy MASK (B,H,W) or (H,W) -> (B,H,W,3) rendered as a single identity (palette[0]) + on the given background. A batch is treated as multiple views of that one subject.""" + device = comfy.model_management.intermediate_device() + dtype = comfy.model_management.intermediate_dtype() + if mask.ndim == 2: + mask = mask.unsqueeze(0) + mask = mask.to(device=device, dtype=dtype) + B, H, W = mask.shape + bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0) + color = torch.tensor(DEFAULT_PALETTE[0], device=device, dtype=dtype).view(1, 1, 1, 3) + bg = torch.tensor(bg_rgb, device=device, dtype=dtype).view(1, 1, 1, 3) + return torch.where((mask > 0.5).unsqueeze(-1), color.expand(B, H, W, 3), bg.expand(B, H, W, 3)) + + def _extract_mask_to_28ch(rgb_video): """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent (1, T_lat, 28, H_lat, W_lat). 
7 per-color binary channels (white/r/g/b/y/m/c) @@ -138,8 +158,8 @@ class WanSCAILToVideo(io.ComfyNode): io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."), io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."), - io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."), - io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."), + io.Image.Input("reference_image", optional=True, tooltip="Reference image. The first image is the primary reference (composite all identities onto it). SCAIL-2: extra batch images are used as additional views (back view, close-up, occluded background), each needing a matching reference_image_mask in that identity's color."), + io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask, batch matching reference_image (first = primary reference mask, rest = identity masks for the additional reference_image)."), io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."), io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. 
SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), @@ -171,19 +191,21 @@ class WanSCAILToVideo(io.ComfyNode): video_frame_offset -= prev_trimmed.shape[0] video_frame_offset = max(0, video_frame_offset) - ref_latent = None if reference_image is not None: - reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) - # Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte - if replacement_mode and reference_image_mask is not None: - rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) - is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype) - reference_image = reference_image * is_char - ref_latent = vae.encode(reference_image[:, :, :, :3]) + ref_imgs = comfy.utils.common_upscale(reference_image.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + n_ref = ref_imgs.shape[0] + # SCAIL-2 multi-reference: the first image is the primary ref, the rest are additional references. 
- if ref_latent is not None: - positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) - negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + # Replacement Mode: composite each ref on black bg using its mask as alpha matte + if replacement_mode and reference_image_mask is not None: + rm = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + rm = rm[[min(i, rm.shape[0] - 1) for i in range(n_ref)]] + is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(ref_imgs.dtype) + ref_imgs = ref_imgs * is_char + # encode each ref individually so each stays a single latent frame (a batched encode would be treated as a video) + ref_latents = [vae.encode(ref_imgs[i:i + 1, :, :, :3]) for i in range(n_ref)] + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": ref_latents}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": ref_latents}, append=True) if clip_vision_output is not None: positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) @@ -221,11 +243,16 @@ class WanSCAILToVideo(io.ComfyNode): positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch}) negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch}) - if reference_image_mask is not None: - ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) - ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw) + # The ref mask binds reference frames to identities, so it only applies when there's a reference image. 
+ if reference_image_mask is not None and reference_image is not None: + ref_mask_hw = comfy.utils.common_upscale(reference_image_mask.movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + n_masks = ref_mask_hw.shape[0] + n_ref = reference_image.shape[0] + + add_masks = [_extract_mask_to_28ch(ref_mask_hw[min(i, n_masks - 1)][None]) for i in range(1, n_ref)] + ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw[:1]) zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype) - ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1) + ref_mask_28ch = torch.cat(add_masks + [ref_mask_1f, zeros], dim=1) positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch}) negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch}) @@ -244,12 +271,9 @@ class WanSCAILToVideo(io.ComfyNode): class SCAIL2ColoredMask(io.ComfyNode): - """Render SAM3 tracks for the driving pose video and (optionally) the reference - image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` - across both outputs guarantees identity K maps to the same color on both - sides, for multi-person workflow consistency. - reference_image_mask is always rendered black-bg (model convention) - pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode + """Render SAM3 tracks for the driving pose video and reference image(s) into the + colored masks WanSCAILToVideo consumes. Shared `sort_by` keeps each identity on the + same color across both outputs. """ @classmethod @@ -260,10 +284,12 @@ class SCAIL2ColoredMask(io.ComfyNode): category="model/conditioning/wan/scail", inputs=[ SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. 
Will be rendered into the pose_video_mask output."), - SAM3TrackData.Input("ref_track_data", optional=True, tooltip="SAM3 track of the reference image."), - io.String.Input("object_indices", default="", tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), + io.MultiType.Input("ref_track_data", [SAM3TrackData, io.Mask], optional=True, display_name="reference_masks", + tooltip="SAM3 track of the reference image(s) (one identity per object, colored in batch order), or a plain MASK of the reference subject (rendered as a single identity)."), + io.String.Input("object_indices", default="", + tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", - tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."), + tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). Objects that appear in earlier frames always come first; within a frame, left_to_right = leftmost object (by centroid at first appearance) gets the first color, area = biggest object (by mask area at first appearance) gets the first color; none = keep SAM3's order."), io.Boolean.Input("replacement_mode", default=False, tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). 
" "True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."), @@ -280,11 +306,11 @@ class SCAIL2ColoredMask(io.ComfyNode): def _prep(td): masks_bool = _unpack(td) if sort_by != "none" and masks_bool is not None: - cx, area = _first_frame_cx_area(masks_bool) + first_t, cx, area = _first_appearance_cx_area(masks_bool) if sort_by == "left_to_right": - order = sorted(range(len(cx)), key=lambda i: cx[i]) + order = sorted(range(len(cx)), key=lambda i: (first_t[i], cx[i])) else: # "area" - order = sorted(range(len(area)), key=lambda i: -area[i]) + order = sorted(range(len(area)), key=lambda i: (first_t[i], -area[i])) td = _subset_track_data(td, order) if object_indices.strip(): indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] @@ -300,8 +326,10 @@ class SCAIL2ColoredMask(io.ComfyNode): ref_bg = "black" if replacement_mode else "white" if ref_track_data is not None: - ref = _prep(ref_track_data) - reference_image_mask = _render_colored_masks(ref, ref_bg) + if isinstance(ref_track_data, torch.Tensor): # plain comfy MASK + reference_image_mask = _render_mask_as_identity(ref_track_data, ref_bg) + else: + reference_image_mask = _render_colored_masks(_prep(ref_track_data), ref_bg) else: H, W = drv["orig_size"] fill_value = 1.0 if ref_bg == "white" else 0.0 diff --git a/comfy_extras/nodes_seed.py b/comfy_extras/nodes_seed.py new file mode 100644 index 000000000..e64f1d7e3 --- /dev/null +++ b/comfy_extras/nodes_seed.py @@ -0,0 +1,33 @@ +import sys +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class SeedNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedNode", + display_name="Seed", + search_aliases=["seed", "random"], + category="utilities", + inputs=[ + io.Int.Input("seed", min=0, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed), + ], + 
outputs=[io.Int.Output(display_name="seed")], + ) + + @classmethod + def execute(cls, seed: int) -> io.NodeOutput: + return io.NodeOutput(seed) + + +class SeedExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [SeedNode] + + +async def comfy_entrypoint() -> SeedExtension: + return SeedExtension() diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 6a78ffb47..ddfb4f2b0 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_SuperResolutionControlnet", - category="experimental/stable_cascade", + category="experimental/stable cascade", is_experimental=True, inputs=[ io.Image.Input("image"), diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 97485c8c5..21929ae63 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -440,6 +440,57 @@ class JsonExtractString(io.ComfyNode): except (json.JSONDecodeError, TypeError): return io.NodeOutput("") + +def _dump_json(value, indent): + return json.dumps(value, ensure_ascii=False, indent=indent or None) + + +class ConvertDictionaryToString(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ConvertDictionaryToString", + display_name="Convert Dictionary to String", + category="text", + search_aliases=["json", "dict to json", "stringify", "serialize", "dict to string"], + inputs=[ + io.Dict.Input("dictionary"), + io.Int.Input("indent", default=2, min=0, max=8, + tooltip="Spaces per indent level. 
0 produces compact single-line string."), + ], + outputs=[ + io.String.Output(), + ], + ) + + @classmethod + def execute(cls, dictionary, indent=2): + return io.NodeOutput(_dump_json(dictionary, indent)) + + +class ConvertArrayToString(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ConvertArrayToString", + display_name="Convert Array to String", + category="text", + search_aliases=["json", "list to json", "stringify", "serialize", "list to string", "array to json"], + inputs=[ + io.Array.Input("array"), + io.Int.Input("indent", default=2, min=0, max=8, + tooltip="Spaces per indent level. 0 produces compact single-line string."), + ], + outputs=[ + io.String.Output(), + ], + ) + + @classmethod + def execute(cls, array, indent=2): + return io.NodeOutput(_dump_json(array, indent)) + + class StringExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -457,6 +508,8 @@ class StringExtension(ComfyExtension): RegexExtract, RegexReplace, JsonExtractString, + ConvertDictionaryToString, + ConvertArrayToString, ] async def comfy_entrypoint() -> StringExtension: diff --git a/comfy_extras/nodes_triposplat.py b/comfy_extras/nodes_triposplat.py index 1848ad31a..c892213e4 100644 --- a/comfy_extras/nodes_triposplat.py +++ b/comfy_extras/nodes_triposplat.py @@ -65,7 +65,7 @@ class TripoSplatPreprocessImage(IO.ComfyNode): return IO.Schema( node_id="TripoSplatPreprocessImage", display_name="TripoSplat Preprocess Image", - category="3d/conditioning", + category="model/conditioning/triposplat", description="Crop center each image to a square canvas on a black background and add padding.", inputs=[ IO.Image.Input("image"), @@ -95,7 +95,7 @@ class TripoSplatConditioning(IO.ComfyNode): return IO.Schema( node_id="TripoSplatConditioning", display_name="TripoSplat Conditioning", - category="3d/conditioning", + category="model/conditioning/triposplat", description="Encode the image with DINOv3 and the Flux2 VAE 
into TripoSplat positive/negative " "conditioning, and create the fixed size noise target (latent + camera) for the KSampler", inputs=[ @@ -143,7 +143,7 @@ class VAEDecodeTripoSplat(IO.ComfyNode): return IO.Schema( node_id="VAEDecodeTripoSplat", display_name="TripoSplat Decode", - category="3d/latent", + category="model/latent/triposplat", description="Decode the sampled TripoSplat latent into a 3D gaussian splat. " "Modify the number of gaussians to vary the density.", inputs=[ @@ -188,7 +188,7 @@ class TripoSplatSamplingPreview(IO.ComfyNode): return IO.Schema( node_id="TripoSplatSamplingPreview", display_name="TripoSplat Sampling Preview", - category="3d/latent", + category="model/latent/triposplat", description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded " "gaussian splat preview at each step.", inputs=[ diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 050a897dd..d3acc9ad0 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -27,6 +27,7 @@ class SaveWEBM(io.ComfyNode): ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, + outputs=[io.Image.Output(display_name="images")] ) @classmethod @@ -69,7 +70,7 @@ class SaveWEBM(io.ComfyNode): container.mux(stream.encode()) container.close() - return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + return io.NodeOutput(images, ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) class SaveVideo(io.ComfyNode): @classmethod @@ -89,6 +90,7 @@ class SaveVideo(io.ComfyNode): ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, + outputs=[io.Video.Output("video")], ) @classmethod @@ -117,7 +119,7 @@ class SaveVideo(io.ComfyNode): metadata=saved_metadata ) - return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + return io.NodeOutput(video, ui=ui.PreviewVideo([ui.SavedResult(file, 
subfolder, io.FolderType.output)])) class CreateVideo(io.ComfyNode): @@ -233,13 +235,8 @@ class VideoSlice(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Video Slice", - display_name="Video Slice", - search_aliases=[ - "trim video duration", - "skip first frames", - "frame load cap", - "start time", - ], + display_name="Trim Video", + search_aliases=["trim video duration", "skip first frames", "frame load cap", "start time"], category="video", essentials_category="Video Tools", inputs=[ diff --git a/comfyui_version.py b/comfyui_version.py index cee317f3d..8e9967f1b 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.25.0" +__version__ = "0.27.0" diff --git a/execution.py b/execution.py index 9e16e451d..c45317593 100644 --- a/execution.py +++ b/execution.py @@ -1308,6 +1308,25 @@ class PromptQueue: queued = copy.copy(self.queue) return (running, queued) + def interrupt_if_running(self, prompt_id): + """Interrupt the running prompt with this id, atomically. + + Checks the live running set and signals the interrupt under the queue + mutex, so the worker cannot move the job to done (and start the next + prompt) in between. Returns True if a matching job was running and an + interrupt was signalled, False otherwise. The atomicity is what keeps a + cancel from landing on an unrelated prompt that started after a separate + is-running check: the global interrupt flag is reset at the start of + every prompt (execute_async), so a job that finishes before consuming + the flag cannot leak the interrupt onto its successor. 
+ """ + with self.mutex: + for item in self.currently_running.values(): + if item[1] == prompt_id: + nodes.interrupt_processing() + return True + return False + def get_tasks_remaining(self): with self.mutex: return len(self.queue) + len(self.currently_running) diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 9c395c0b2..6a31d8a63 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -8,21 +8,37 @@ # # You can use is_default to mark that these folders should be listed first, and used as the default dirs for eg downloads # #is_default: true # checkpoints: models/checkpoints/ +# configs: models/configs/ +# loras: models/loras/ +# vae: models/vae/ # text_encoders: | # models/text_encoders/ -# models/clip/ # legacy location still supported -# clip_vision: models/clip_vision/ -# configs: models/configs/ -# controlnet: models/controlnet/ +# models/clip/ # diffusion_models: | -# models/diffusion_models -# models/unet +# models/unet/ +# models/diffusion_models/ +# clip_vision: models/clip_vision/ +# style_models: models/style_models/ # embeddings: models/embeddings/ -# loras: models/loras/ +# diffusers: models/diffusers/ +# vae_approx: models/vae_approx/ +# controlnet: | +# models/controlnet/ +# models/t2i_adapter/ +# gligen: models/gligen/ # upscale_models: models/upscale_models/ -# vae: models/vae/ -# audio_encoders: models/audio_encoders/ +# latent_upscale_models: models/latent_upscale_models/ +# custom_nodes: custom_nodes/ +# hypernetworks: models/hypernetworks/ +# photomaker: models/photomaker/ +# classifiers: models/classifiers/ # model_patches: models/model_patches/ +# audio_encoders: models/audio_encoders/ +# background_removal: models/background_removal/ +# frame_interpolation: models/frame_interpolation/ +# geometry_estimation: models/geometry_estimation/ +# optical_flow: models/optical_flow/ +# detection: models/detection/ #config for a1111 ui @@ -45,8 +61,7 @@ # controlnet: models/ControlNet -# For 
a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker, -# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py. +# For the canonical list of supported keys and extensions, see folder_paths.py. #other_ui: # base_path: path/to/ui diff --git a/main.py b/main.py index ad5c11e16..20ec83c9e 100644 --- a/main.py +++ b/main.py @@ -403,7 +403,7 @@ def prompt_worker(q, server_instance): hook_breaker_ac10a0.restore_functions() if not asset_seeder.is_disabled(): - asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True) + asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=args.enable_asset_hashing) asset_seeder.resume() @@ -458,7 +458,7 @@ def setup_database(): if dependencies_available(): init_db() if args.enable_assets: - if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True): + if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=args.enable_asset_hashing): logging.info("Background asset scan initiated for models, input, output") except Exception as e: if "database is locked" in str(e): @@ -557,8 +557,13 @@ if __name__ == "__main__": logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") if args.disable_dynamic_vram: - logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.") - + logging.warning( + "Dynamic vram disabled with argument. If you have any issues with " + "dynamic vram enabled please give us a detailed reports as this " + "argument will be removed soon. If you use gguf we recommend keeping " + "dynamic vram enabled and using native ComfyUI model formats instead. " + "ComfyUI native formats like fp8 will be faster even if they are larger than your memory." 
+ ) event_loop, _, start_all_func = start_comfyui() try: x = start_all_func() diff --git a/nodes.py b/nodes.py index 0eff30ef2..200d7c6a5 100644 --- a/nodes.py +++ b/nodes.py @@ -20,8 +20,6 @@ from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch -sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) - import comfy.diffusers_load import comfy.samplers import comfy.sample @@ -161,6 +159,29 @@ class ConditioningConcat: return (out, ) +class ConditioningMultiply: + SEARCH_ALIASES = ["scale conditioning", "scale prompt", "multiply conditioning", "multiply prompt"] + + @classmethod + def INPUT_TYPES(cls): + return {"required": {"conditioning": ("CONDITIONING", ), + "multiplier": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "multiply" + CATEGORY = "model/conditioning/transform" + + def multiply(self, conditioning, multiplier): + c = [] + for t in conditioning: + values = {} + pooled_output = t[1].get("pooled_output", None) + if pooled_output is not None: + values["pooled_output"] = pooled_output * multiplier + scaled = node_helpers.conditioning_set_values([[t[0] * multiplier, t[1]]], values)[0] + c.append(scaled) + return (c,) + class ConditioningSetArea: SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"] @@ -328,7 +349,7 @@ class VAEDecodeTiled: RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" - CATEGORY = "experimental" + CATEGORY = "model/latent" def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: @@ -375,7 +396,7 @@ class VAEEncodeTiled: RETURN_TYPES = ("LATENT",) FUNCTION = "encode" - CATEGORY = "experimental" + CATEGORY = "model/latent" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, 
tile_t=temporal_size, overlap_t=temporal_overlap) @@ -482,16 +503,18 @@ class SaveLatent: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () + return { "required": { + "samples": ("LATENT",), + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) FUNCTION = "save" OUTPUT_NODE = True - CATEGORY = "experimental" + CATEGORY = "model/latent" def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) @@ -524,7 +547,7 @@ class SaveLatent: output["latent_format_version_0"] = torch.tensor([]) comfy.utils.save_torch_file(output, file, metadata=metadata) - return { "ui": { "latents": results } } + return { "ui": { "latents": results }, "result": (samples,) } class LoadLatent: @@ -536,7 +559,7 @@ class LoadLatent: files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] return {"required": {"latent": [sorted(files), ]}, } - CATEGORY = "experimental" + CATEGORY = "model/latent" RETURN_TYPES = ("LATENT", ) FUNCTION = "load" @@ -969,7 +992,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "joyimage"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", 
"stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2", "joyimage"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1629,14 +1652,18 @@ class SaveImage: return { "required": { "images": ("IMAGE", {"tooltip": "The images to save."}), - "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + "filename_prefix": ("STRING", { + "default": "ComfyUI", + "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes." + }) }, "hidden": { "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" }, } - RETURN_TYPES = () + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) FUNCTION = "save_images" OUTPUT_NODE = True @@ -1672,7 +1699,7 @@ class SaveImage: }) counter += 1 - return { "ui": { "images": results } } + return { "ui": { "images": results }, "result" : (images,) } class PreviewImage(SaveImage): def __init__(self): @@ -2046,6 +2073,7 @@ NODE_CLASS_MAPPINGS = { "ConditioningAverage": ConditioningAverage, "ConditioningCombine": ConditioningCombine, "ConditioningConcat": ConditioningConcat, + "ConditioningMultiply": ConditioningMultiply, "ConditioningSetArea": ConditioningSetArea, "ConditioningSetAreaPercentage": ConditioningSetAreaPercentage, "ConditioningSetAreaStrength": ConditioningSetAreaStrength, @@ -2117,6 +2145,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ConditioningAverage ": "Conditioning (Average)", "ConditioningAverage": "Conditioning (Average)", "ConditioningConcat": "Conditioning (Concat)", + "ConditioningMultiply": "Conditioning (Multiply)", "ConditioningSetArea": "Conditioning 
(Set Area)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetAreaStrength": "Conditioning (Set Area Strength)", @@ -2126,6 +2155,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "GLIGENTextBoxApply": "Apply GLIGEN Text Box", "ConditioningZeroOut": "Conditioning Zero Out", # Latent + "LoadLatent": "Load Latent", + "SaveLatent": "Save Latent", "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "SetLatentNoiseMask": "Set Latent Noise Mask", "VAEDecode": "VAE Decode", @@ -2160,7 +2191,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ImageSharpen": "Sharpen Image", "ImageScaleToTotalPixels": "Scale Image to Total Pixels", "GetImageSize": "Get Image Size", - # experimental "VAEDecodeTiled": "VAE Decode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)", } @@ -2299,6 +2329,9 @@ async def init_external_custom_nodes(): Returns: None """ + # TODO: remove at some point when custom nodes don't break. + sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) + base_node_names = set(NODE_CLASS_MAPPINGS.keys()) node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] @@ -2367,6 +2400,8 @@ async def init_builtin_extra_nodes(): "nodes_images.py", "nodes_video_model.py", "nodes_ideogram4.py", + "nodes_bounding_boxes.py", + "nodes_json_prompt.py", "nodes_train.py", "nodes_dataset.py", "nodes_sag.py", @@ -2426,6 +2461,7 @@ async def init_builtin_extra_nodes(): "nodes_context_windows.py", "nodes_qwen.py", "nodes_joyimage.py", + "nodes_boogu.py", "nodes_chroma_radiance.py", "nodes_pid.py", "nodes_model_patch.py", @@ -2466,6 +2502,7 @@ async def init_builtin_extra_nodes(): "nodes_gaussian_splat.py", "nodes_triposplat.py", "nodes_depth_anything_3.py", + "nodes_seed.py", ] import_failed = [] diff --git a/openapi.yaml b/openapi.yaml index 82ff5b003..c6a8621cc 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -55,6 +55,12 @@ components: description: URL for asset preview/thumbnail format: uri type: string + short_url: 
+ description: Durable, owner-gated short link to this asset's content (relative `/api/s/{id}` path). Stable across the underlying signed URL's expiry — resolving it re-mints a fresh signed URL on every request — so it is safe to persist or share into chat, unlike `preview_url`. Only the minting user can resolve it. Omitted when the short-link surface is disabled or the asset has no resolvable content hash. + nullable: true + type: string + x-runtime: + - cloud size: description: Size of the asset in bytes format: int64 @@ -673,6 +679,35 @@ components: - created_at - updated_at type: object + JobsCancelRequest: + additionalProperties: false + description: Request to cancel multiple jobs by ID. + properties: + job_ids: + description: Job identifiers (UUIDs) to cancel. + items: + format: uuid + type: string + maxItems: 100 + minItems: 1 + type: array + required: + - job_ids + type: object + JobsCancelResponse: + description: Response for POST /api/jobs/cancel. + properties: + cancelled: + description: | + Job IDs for which a cancel event was successfully dispatched by this + call. Jobs already in a terminal or cancelling state are idempotently + skipped and will not appear here. + items: + type: string + type: array + required: + - cancelled + type: object JobsListResponse: description: Paginated list of jobs for the authenticated user. 
properties: @@ -1006,7 +1041,7 @@ components: description: If true, clear all pending jobs from the queue type: boolean delete: - description: Array of PENDING job IDs to cancel + description: Array of job IDs to cancel; pending and running jobs transition to cancelled items: type: string type: array @@ -1657,6 +1692,12 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' description: Unsupported media type + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error (e.g., disallowed model_type tag) "500": content: application/json: @@ -1822,6 +1863,83 @@ paths: summary: Update asset metadata tags: - file + /api/assets/{id}/content: + get: + description: | + Returns the binary content of an asset by ID. + + The contract is the same across runtimes — "GET this path and you + receive the asset's bytes" — but the mechanism differs: + - **Local ComfyUI** streams the bytes directly (`200`, + `application/octet-stream`). + - **Cloud** does not proxy large files; it responds `302` with a + `Location` redirect to a short-lived signed storage URL. Clients that + follow redirects (browsers, `fetch`/XHR, ``/`