diff --git a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat similarity index 66% rename from .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat rename to .ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat index cece0aeb2..94ad31942 100755 --- a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat +++ b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat @@ -1,2 +1,2 @@ -.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram pause diff --git a/.github/workflows/openapi-lint.yml b/.github/workflows/openapi-lint.yml new file mode 100644 index 000000000..be949de2a --- /dev/null +++ b/.github/workflows/openapi-lint.yml @@ -0,0 +1,31 @@ +name: OpenAPI Lint + +on: + pull_request: + paths: + - 'openapi.yaml' + - '.spectral.yaml' + - '.github/workflows/openapi-lint.yml' + +permissions: + contents: read + +jobs: + spectral: + name: Run Spectral + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + + - name: Install Spectral + run: npm install -g @stoplight/spectral-cli@6 + + - name: Lint openapi.yaml + run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error diff --git a/.gitignore b/.gitignore index 0ab4ba75e..fc426eda4 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ web_custom_versions/ .DS_Store filtered-openapi.yaml uv.lock +.comfy_environment diff --git a/.spectral.yaml b/.spectral.yaml new file mode 100644 index 000000000..4bb4a4a94 --- /dev/null +++ b/.spectral.yaml @@ -0,0 +1,91 @@ +extends: + - spectral:oas + +# Severity levels: error, warn, info, hint, off +# Rules from the built-in "spectral:oas" ruleset are active by default. +# Below we tune severity and add custom rules for our conventions. +# +# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the +# organization are linted against a single consistent standard. + +rules: + # ----------------------------------------------------------------------- + # Built-in rule severity overrides + # ----------------------------------------------------------------------- + operation-operationId: error + operation-description: warn + operation-tag-defined: error + info-contact: off + info-description: warn + no-eval-in-markdown: error + no-$ref-siblings: error + + # ----------------------------------------------------------------------- + # Custom rules: naming conventions + # ----------------------------------------------------------------------- + + # Property names should be snake_case + property-name-snake-case: + description: Property names must be snake_case + severity: warn + given: "$.components.schemas.*.properties[*]~" + then: + function: pattern + functionOptions: + match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$" + + # Operation IDs should be camelCase + operation-id-camel-case: + description: Operation IDs must be camelCase + severity: warn + given: "$.paths.*.*.operationId" + then: + function: pattern + functionOptions: + match: "^[a-z][a-zA-Z0-9]*$" + + # ----------------------------------------------------------------------- + # Custom rules: response conventions + # ----------------------------------------------------------------------- + + # Error responses (4xx, 5xx) should use a consistent shape + error-response-schema: + description: Error responses should reference a standard error schema + severity: hint + given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema" + then: + field: "$ref" + function: truthy + + # All 2xx responses with JSON body should have a schema + response-schema-defined: + description: Success responses with JSON content should define a schema + severity: warn + given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']" + then: + field: schema + function: truthy + + # ----------------------------------------------------------------------- + # Custom rules: best practices + # ----------------------------------------------------------------------- + + # Path parameters must have a description + path-param-description: + description: Path parameters should have a description + severity: warn + given: + - "$.paths.*.parameters[?(@.in == 'path')]" + - "$.paths.*.*.parameters[?(@.in == 'path')]" + then: + field: description + function: truthy + + # Schemas should have a description + schema-description: + description: Component schemas should have a description + severity: hint + given: "$.components.schemas.*" + then: + field: description + function: truthy diff --git a/CODEOWNERS b/CODEOWNERS index e693955a0..946dbf946 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,2 @@ # Admins -* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 +* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai diff --git a/README.md b/README.md index f05311421..0fd317d0a 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
# ComfyUI -**The most powerful and modular visual AI engine and application.** +**The most powerful and modular AI engine for content creation.** [![Website][website-shield]][website-url] @@ -31,10 +31,16 @@ [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases -![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe) +ComfyUI Screenshot +
-ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS. +ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more... +- ComfyUI natively supports the latest open-source state of the art models. +- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc. +- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud. +- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode. +- It integrates seamlessly into production pipelines with our API endpoints. ## Get Started @@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) + - Ernie Image - Image Editing Models - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) @@ -126,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks. - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. - Minor versions will be used for releases off the master branch. - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. @@ -193,13 +200,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start. -#### Alternative Downloads: +#### All Official Portable Downloads: [Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) -[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) +[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) -[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). +[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above). + +[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). #### How do I share models between another UI and ComfyUI? diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cef1a5e6b..9dadb0093 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE" parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") +parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.") class LatentPreviewMethod(enum.Enum): NoPreviews = "none" @@ -237,6 +238,8 @@ database_default_path = os.path.abspath( ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") +parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button") +parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/comfy/deploy_environment.py b/comfy/deploy_environment.py new file mode 100644 index 000000000..8c99a3584 --- /dev/null +++ b/comfy/deploy_environment.py @@ -0,0 +1,34 @@ +import functools +import logging +import os + +logger = logging.getLogger(__name__) + +_DEFAULT_DEPLOY_ENV = "local-git" +_ENV_FILENAME = ".comfy_environment" + +# Resolve the ComfyUI install directory (the parent of this `comfy/` package). +# We deliberately avoid `folder_paths.base_path` here because that is overridden +# by the `--base-directory` CLI arg to a user-supplied path, whereas the +# `.comfy_environment` marker is written by launchers/installers next to the +# ComfyUI install itself. +_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + + +@functools.cache +def get_deploy_environment() -> str: + env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME) + try: + with open(env_file, encoding="utf-8") as f: + # Cap the read so a malformed or maliciously crafted file (e.g. + # a single huge line with no newline) can't blow up memory. + first_line = f.readline(128).strip() + value = "".join(c for c in first_line if 32 <= ord(c) < 127) + if value: + return value + except FileNotFoundError: + pass + except Exception as e: + logger.error("Failed to read %s: %s", env_file, e) + + return _DEFAULT_DEPLOY_ENV diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..d33bc7199 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) + + +@torch.no_grad() +def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None, + num_frame_per_block=1): + """ + Autoregressive video sampler: block-by-block denoising with KV cache + and flow-match re-noising for Causal Forcing / Self-Forcing models. + + Requires a Causal-WAN compatible model (diffusion_model must expose + init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W]. + + All AR-loop parameters are passed via the SamplerARVideo node, not read + from the checkpoint or transformer_options. + """ + extra_args = {} if extra_args is None else extra_args + model_options = extra_args.get("model_options", {}) + transformer_options = model_options.get("transformer_options", {}) + + if x.ndim != 5: + raise ValueError( + f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. " + "This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)." + ) + + inner_model = model.inner_model.inner_model + causal_model = inner_model.diffusion_model + + if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")): + raise TypeError( + "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model " + "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint " + "does not support this interface — choose a different sampler." + ) + + seed = extra_args.get("seed", 0) + + bs, c, lat_t, lat_h, lat_w = x.shape + frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division + device = x.device + model_dtype = inner_model.get_dtype() + + kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype) + crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype) + + output = torch.zeros_like(x) + s_in = x.new_ones([x.shape[0]]) + current_start_frame = 0 + num_sigma_steps = len(sigmas) - 1 + total_real_steps = num_blocks * num_sigma_steps + step_count = 0 + + try: + for block_idx in trange(num_blocks, disable=disable): + bf = min(num_frame_per_block, lat_t - current_start_frame) + fs, fe = current_start_frame, current_start_frame + bf + noisy_input = x[:, :, fs:fe] + + ar_state = { + "start_frame": current_start_frame, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } + transformer_options["ar_state"] = ar_state + + for i in range(num_sigma_steps): + denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) + + if callback is not None: + scaled_i = step_count * num_sigma_steps // total_real_steps + callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], + "sigma_hat": sigmas[i], "denoised": denoised}) + + if sigmas[i + 1] == 0: + noisy_input = denoised + else: + sigma_next = sigmas[i + 1] + torch.manual_seed(seed + block_idx * 1000 + i) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + + for cache in kv_caches: + cache["end"] -= bf * frame_seq_len + + step_count += 1 + + output[:, :, fs:fe] = noisy_input + + for cache in kv_caches: + cache["end"] -= bf * frame_seq_len + zero_sigma = sigmas.new_zeros([1]) + _ = model(noisy_input, zero_sigma * s_in, **extra_args) + + current_start_frame += bf + finally: + transformer_options.pop("ar_state", None) + + return output diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 6f2ba41ef..3fb87b4a3 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import ( from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector import comfy.ldm.common_dit +import comfy.model_prefetch class CompressedTimestep: """Store video timestep embeddings in compressed form using per-frame indexing.""" @@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel): """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options) # Process transformer blocks for i, block in enumerate(self.transformer_blocks): + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block) if ("double_block", i) in blocks_replace: def block_wrap(args): @@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel): a_prompt_timestep=a_prompt_timestep, ) + comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None) + return [vx, ax] def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b193fe5e8..a68cb8439 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management +TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5) + if model_management.xformers_enabled(): import xformers import xformers.ops @@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = q.shape dim_head //= heads - scale = dim_head ** -0.5 + if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]: + n_rep = q.shape[-3] // k.shape[-3] + k = k.repeat_interleave(n_rep, dim=-3) + v = v.repeat_interleave(n_rep, dim=-3) + + scale = kwargs.get("scale", dim_head ** -0.5) h = heads if skip_reshape: @@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, b, _, dim_head = query.shape dim_head //= heads + if "scale" in kwargs: + # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head)) + query = query * (kwargs["scale"] * dim_head ** 0.5) + if skip_reshape: query = query.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head) @@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape b, _, dim_head = q.shape dim_head //= heads - scale = dim_head ** -0.5 + scale = kwargs.get("scale", dim_head ** -0.5) if skip_reshape: q, k, v = map( @@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.ndim == 3: mask = mask.unsqueeze(1) + # Pass through extra SDPA kwargs (scale, enable_gqa) if provided + # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above + sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",) + sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys} + if SDP_BATCH_LIMIT >= b: - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], attn_mask=m, - dropout_p=0.0, is_causal=False + dropout_p=0.0, is_causal=False, **sdpa_extra ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) return out diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py new file mode 100644 index 000000000..d72f53602 --- /dev/null +++ b/comfy/ldm/wan/ar_model.py @@ -0,0 +1,276 @@ +""" +CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for +autoregressive (frame-by-frame) video generation via Causal Forcing. + +Weight-compatible with the standard WanModel -- same layer names, same shapes. +The difference is purely in the forward pass: this model processes one temporal +block at a time and maintains a KV cache across blocks. + +Reference: https://github.com/thu-ml/Causal-Forcing +""" + +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.wan.model import ( + sinusoidal_embedding_1d, + repeat_e, + WanModel, + WanAttentionBlock, +) +import comfy.ldm.common_dit +import comfy.model_management + + +class CausalWanSelfAttention(nn.Module): + """Self-attention with KV cache support for autoregressive inference.""" + + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, + eps=1e-6, operation_settings={}): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qk_norm = qk_norm + self.eps = eps + + ops = operation_settings.get("operations") + device = operation_settings.get("device") + dtype = operation_settings.get("dtype") + + self.q = ops.Linear(dim, dim, device=device, dtype=dtype) + self.k = ops.Linear(dim, dim, device=device, dtype=dtype) + self.v = ops.Linear(dim, dim, device=device, dtype=dtype) + self.o = ops.Linear(dim, dim, device=device, dtype=dtype) + self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity() + + def forward(self, x, freqs, kv_cache=None, transformer_options={}): + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs) + k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs) + v = self.v(x).view(b, s, n, d) + + if kv_cache is None: + x = optimized_attention( + q.view(b, s, n * d), + k.view(b, s, n * d), + v.view(b, s, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + else: + end = kv_cache["end"] + new_end = end + s + + # Roped K and plain V go into cache + kv_cache["k"][:, end:new_end] = k + kv_cache["v"][:, end:new_end] = v + kv_cache["end"] = new_end + + x = optimized_attention( + q.view(b, s, n * d), + kv_cache["k"][:, :new_end].view(b, new_end, n * d), + kv_cache["v"][:, :new_end].view(b, new_end, n * d), + heads=self.num_heads, + transformer_options=transformer_options, + ) + + x = self.o(x) + return x + + +class CausalWanAttentionBlock(WanAttentionBlock): + """Transformer block with KV-cached self-attention and cross-attention caching.""" + + def __init__(self, cross_attn_type, dim, ffn_dim, num_heads, + window_size=(-1, -1), qk_norm=True, cross_attn_norm=False, + eps=1e-6, operation_settings={}): + super().__init__(cross_attn_type, dim, ffn_dim, num_heads, + window_size, qk_norm, cross_attn_norm, eps, + operation_settings=operation_settings) + self.self_attn = CausalWanSelfAttention( + dim, num_heads, window_size, qk_norm, eps, + operation_settings=operation_settings) + + def forward(self, x, e, freqs, context, context_img_len=257, + kv_cache=None, crossattn_cache=None, transformer_options={}): + if e.ndim < 4: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) + else: + e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2) + + # Self-attention with optional KV cache + x = x.contiguous() + y = self.self_attn( + torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), + freqs, kv_cache=kv_cache, transformer_options=transformer_options) + x = torch.addcmul(x, y, repeat_e(e[2], x)) + del y + + # Cross-attention with optional caching + if crossattn_cache is not None and crossattn_cache.get("is_init"): + q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x))) + x_ca = optimized_attention( + q, crossattn_cache["k"], crossattn_cache["v"], + heads=self.num_heads, transformer_options=transformer_options) + x = x + self.cross_attn.o(x_ca) + else: + x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + if crossattn_cache is not None: + crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context)) + crossattn_cache["v"] = self.cross_attn.v(context) + crossattn_cache["is_init"] = True + + # FFN + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) + x = torch.addcmul(x, y, repeat_e(e[5], x)) + return x + + +class CausalWanModel(WanModel): + """ + Wan 2.1 diffusion backbone with causal KV-cache support. + + Same weight structure as WanModel -- loads identical state dicts. + Adds forward_block() for frame-by-frame autoregressive inference. + """ + + def __init__(self, + model_type='t2v', + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + image_model=None, + device=None, + dtype=None, + operations=None): + super().__init__( + model_type=model_type, patch_size=patch_size, text_len=text_len, + in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, + text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, + num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, + wan_attn_block_class=CausalWanAttentionBlock, + device=device, dtype=dtype, operations=operations) + + def forward_block(self, x, timestep, context, start_frame, + kv_caches, crossattn_caches, clip_fea=None): + """ + Forward one temporal block for autoregressive inference. + + Args: + x: [B, C, block_frames, H, W] input latent for the current block + timestep: [B, block_frames] per-frame timesteps + context: [B, L, text_dim] raw text embeddings (pre-text_embedding) + start_frame: temporal frame index for RoPE offset + kv_caches: list of per-layer KV cache dicts + crossattn_caches: list of per-layer cross-attention cache dicts + clip_fea: optional CLIP features for I2V + + Returns: + flow_pred: [B, C_out, block_frames, H, W] flow prediction + """ + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + bs, c, t, h, w = x.shape + + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + x = x.flatten(2).transpose(1, 2) + + # Per-frame time embedding + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) + e = e.reshape(timestep.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # Text embedding (reuses crossattn_cache after first block) + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None and self.img_emb is not None: + context_clip = self.img_emb(clip_fea) + context = torch.concat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + # RoPE for current block's temporal position + freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype) + + # Transformer blocks + for i, block in enumerate(self.blocks): + x = block(x, e=e0, freqs=freqs, context=context, + context_img_len=context_img_len, + kv_cache=kv_caches[i], + crossattn_cache=crossattn_caches[i]) + + # Head + x = self.head(x, e) + + # Unpatchify + x = self.unpatchify(x, grid_sizes) + return x[:, :, :t, :h, :w] + + def init_kv_caches(self, batch_size, max_seq_len, device, dtype): + """Create fresh KV caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({ + "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype), + "end": 0, + }) + return caches + + def init_crossattn_caches(self, batch_size, device, dtype): + """Create fresh cross-attention caches for all layers.""" + caches = [] + for _ in range(self.num_layers): + caches.append({"is_init": False}) + return caches + + def reset_kv_caches(self, kv_caches): + """Reset KV caches to empty (reuse allocated memory).""" + for cache in kv_caches: + cache["end"] = 0 + + def reset_crossattn_caches(self, crossattn_caches): + """Reset cross-attention caches.""" + for cache in crossattn_caches: + cache["is_init"] = False + + @property + def head_dim(self): + return self.dim // self.num_heads + + def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + ar_state = transformer_options.get("ar_state") + if ar_state is not None: + bs = x.shape[0] + block_frames = x.shape[2] + t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames) + return self.forward_block( + x=x, timestep=t_per_frame, context=context, + start_frame=ar_state["start_frame"], + kv_caches=ar_state["kv_caches"], + crossattn_caches=ar_state["crossattn_caches"], + clip_fea=clip_fea, + ) + + return super().forward(x, timestep, context, clip_fea=clip_fea, + time_dim_concat=time_dim_concat, + transformer_options=transformer_options, **kwargs) diff --git a/comfy/lora.py b/comfy/lora.py index e4337c729..db8f16bcb 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -17,6 +17,7 @@ """ from __future__ import annotations +import comfy.memory_management import comfy.utils import comfy.model_management import comfy.model_base @@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori weight = old_weight return weight + +def prefetch_prepared_value(value, allocate_buffer, stream): + if isinstance(value, torch.Tensor): + dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + return comfy.memory_management.interpret_gathered_like([value], dest)[0] + elif isinstance(value, weight_adapter.WeightAdapterBase): + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + elif isinstance(value, tuple): + return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + elif isinstance(value, list): + return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + + return value diff --git a/comfy/model_base.py b/comfy/model_base.py index 50dab5782..57a1e44d2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.wan.model_animate +import comfy.ldm.wan.ar_model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -214,6 +215,11 @@ class BaseModel(torch.nn.Module): if "latent_shapes" in extra_conds: xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + transformer_options = transformer_options.copy() + transformer_options["prefetch_dynamic_vbars"] = ( + self.current_patcher is not None and self.current_patcher.is_dynamic() + ) + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) if len(model_output) > 1 and not torch.is_tensor(model_output): model_output, _ = utils.pack_latents(model_output) @@ -1360,6 +1366,13 @@ class WAN21(BaseModel): return out +class WAN21_CausalAR(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.wan.ar_model.CausalWanModel) + self.image_to_video = False + + class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) diff --git a/comfy/model_management.py b/comfy/model_management.py index f86e2a4aa..21738a4c7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.vram_buffer class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -720,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu else: minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) - models_temp = set() + # Order-preserving dedup. A plain set() would randomize iteration order across runs + models_temp = {} for m in models: - models_temp.add(m) + models_temp[m] = None for mm in m.model_patches_models(): - models_temp.add(mm) + models_temp[mm] = None - models = models_temp + models = list(models_temp) + models.reverse() models_to_load = [] @@ -1175,6 +1178,10 @@ stream_counters = {} STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) +STREAM_AIMDO_CAST_BUFFERS = {} +LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + +DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 def get_cast_buffer(offload_stream, device, size, ref): global LARGEST_CASTED_WEIGHT @@ -1208,13 +1215,26 @@ def get_cast_buffer(offload_stream, device, size, ref): return cast_buffer +def get_aimdo_cast_buffer(offload_stream, device): + cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None: + cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) + STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer + + return cast_buffer def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT + global LARGEST_AIMDO_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) - for offload_stream in STREAM_CAST_BUFFERS: - offload_stream.synchronize() + LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + if offload_stream is not None: + offload_stream.synchronize() synchronize() + STREAM_CAST_BUFFERS.clear() + STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e259aed63..7d2d6883f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -121,9 +121,20 @@ class LowVramPatch: self.patches = patches self.convert_func = convert_func # TODO: remove self.set_func = set_func + self.prepared_patches = None + + def prepare(self, allocate_buffer, stream): + self.prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + for patch in self.patches[self.key] + ] + + def clear_prepared(self): + self.prepared_patches = None def __call__(self, weight): - return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) + patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key] + return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype) LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py new file mode 100644 index 000000000..72e11dec6 --- /dev/null +++ b/comfy/model_prefetch.py @@ -0,0 +1,66 @@ +import comfy_aimdo.model_vbar +import comfy.model_management +import comfy.ops + +PREFETCH_QUEUES = [] + +def cleanup_prefetched_modules(comfy_modules): + for s in comfy_modules: + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + continue + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + if prefetch["signature"] is not None: + comfy_aimdo.model_vbar.vbar_unpin(s._v) + delattr(s, "_prefetch") + +def cleanup_prefetch_queues(): + global PREFETCH_QUEUES + + for queue in PREFETCH_QUEUES: + for entry in queue: + if entry is None or not isinstance(entry, tuple): + continue + _, prefetch_state = entry + comfy_modules = prefetch_state[1] + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + PREFETCH_QUEUES = [] + +def prefetch_queue_pop(queue, device, module): + if queue is None: + return + + consumed = queue.pop(0) + if consumed is not None: + offload_stream, prefetch_state = consumed + if offload_stream is not None: + offload_stream.wait_stream(comfy.model_management.current_stream(device)) + _, comfy_modules = prefetch_state + if comfy_modules is not None: + cleanup_prefetched_modules(comfy_modules) + + prefetch = queue[0] + if prefetch is not None: + comfy_modules = [] + for s in prefetch.modules(): + if hasattr(s, "_v"): + comfy_modules.append(s) + + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + comfy.model_management.sync_stream(device, offload_stream) + queue[0] = (offload_stream, (prefetch, comfy_modules)) + +def make_prefetch_queue(queue, device, transformer_options): + if (not transformer_options.get("prefetch_dynamic_vbars", False) + or comfy.model_management.NUM_STREAMS == 0 + or comfy.model_management.is_device_cpu(device) + or not comfy.model_management.device_supports_non_blocking(device)): + return None + + queue = [None] + queue + [None] + PREFETCH_QUEUES.append(queue) + return queue diff --git a/comfy/ops.py b/comfy/ops.py index 050f7cda0..585c185a3 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys): setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - #vbar doesn't support CPU weights, but some custom nodes have weird paths - #that might switch the layer to the CPU and expect it to work. We have to take - #a clone conservatively as we are mmapped and some SFT files are packed misaligned - #If you are a custom node author reading this, please move your layer to the GPU - #or declare your ModelPatcher as CPU in the first place. - if comfy.model_management.is_device_cpu(device): - materialize_meta_param(s, ["weight", "bias"]) - weight = s.weight.to(dtype=dtype, copy=True) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - bias = None - if s.bias is not None: - bias = s.bias.to(dtype=bias_dtype, copy=True) - return weight, bias, (None, None, None) - +# FIXME: add n=1 cache hit fast path +def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking): offload_stream = None - xfer_dest = None + cast_buffer = None + cast_buffer_offset = 0 + + def ensure_offload_stream(module, required_size, check_largest): + nonlocal offload_stream + nonlocal cast_buffer + + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if offload_stream is None or not check_largest or len(comfy_modules) != 1: + return + + current_size = 0 if cast_buffer is None else cast_buffer.size() + if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = None + if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size) + + def get_cast_buffer(buffer_size): + nonlocal offload_stream + nonlocal cast_buffer + nonlocal cast_buffer_offset + + if buffer_size == 0: + return None + + if offload_stream is None: + return torch.empty((buffer_size,), dtype=torch.uint8, device=device) + + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device) + cast_buffer_offset += buffer_size + return buffer + + for s in comfy_modules: + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + prefetch = { + "signature": signature, + "resident": resident, + } - signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) - if signature is not None: if resident: - weight = s._v_weight - bias = s._v_bias - else: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + s._prefetch = prefetch + continue - if not resident: materialize_meta_param(s, ["weight", "bias"]) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None + needs_cast = False xfer_source = [ s.weight, s.bias ] @@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if data is None: continue if data.dtype != geometry.dtype: + needs_cast = True cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) xfer_dest = None break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None: - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) - if xfer_dest is None: - offload_stream = comfy.model_management.get_offload_stream(device) - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True) if xfer_dest is None: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + xfer_dest = get_cast_buffer(dest_size) if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) - comfy.model_management.sync_stream(device, offload_stream) - if cast_dest is not None: + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + ensure_offload_stream(s, cast_buffer_offset, False) + lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + + prefetch["xfer_dest"] = xfer_dest + prefetch["cast_dest"] = cast_dest + prefetch["cast_geometry"] = cast_geometry + prefetch["needs_cast"] = needs_cast + s._prefetch = prefetch + + return offload_stream + + +def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant): + + prefetch = getattr(s, "_prefetch", None) + + if prefetch["resident"]: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = prefetch["xfer_dest"] + if prefetch["needs_cast"]: + cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device) for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), - comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest) weight = params[0] bias = params[1] - if signature is not None: + if prefetch["signature"] is not None: s._v_weight = weight s._v_bias = bias - s._v_signature=signature + s._v_signature = prefetch["signature"] def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) + if x is None: + return None + orig = x def to_dequant(tensor, dtype): @@ -205,14 +248,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = f(x) return x - update_weight = signature is not None + update_weight = prefetch["signature"] is not None + weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) + if bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) - weight = post_cast(s, "weight", weight, dtype, resident, update_weight) - if s.bias is not None: - bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + if prefetch["signature"] is not None: + prefetch["resident"] = True - #FIXME: weird offload return protocol - return weight, bias, (offload_stream, device if signature is not None else None, None) + return weight, bias def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): @@ -230,10 +274,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + def format_return(result, offloadable): + weight, bias, offload_stream = result + return (weight, bias, offload_stream) if offloadable else (weight, bias) + non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): - return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if comfy.model_management.is_device_cpu(device): + materialize_meta_param(s, ["weight", "bias"]) + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None + return format_return((weight, bias, (None, None, None)), offloadable) + + prefetched = hasattr(s, "_prefetch") + offload_stream = None + offload_device = None + if not prefetched: + offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking) + comfy.model_management.sync_stream(device, offload_stream) + + weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant) + + if not prefetched: + if getattr(s, "_prefetch")["signature"] is not None: + offload_device = device + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + lowvram_fn.clear_prepared() + delattr(s, "_prefetch") + return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): @@ -280,11 +360,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.weight_function: weight = f(weight) - if offloadable: - return weight, bias, (offload_stream, weight_a, bias_a) - else: - #Legacy function signature - return weight, bias + return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable) def uncast_bias_weight(s, weight, bias, offload_stream): @@ -1173,6 +1249,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self._buffers[key] = fn(buf) return self + class Embedding(manual_cast.Embedding): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + weight_key = f"{prefix}weight" + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + # Only fp8 makes sense for embeddings (per-row dequant via index select). + # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. + quant_format = layer_conf.get("format", None) if layer_conf is not None else None + if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + self.quant_format = quant_format + qconfig = QUANT_ALGOS[quant_format] + layout_cls = get_layout_class(qconfig["comfy_tensor_layout"]) + weight = state_dict.pop(weight_key) + manually_loaded_keys = [weight_key] + + scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(scale_key, None) + if scale is not None: + scale = scale.float() + manually_loaded_keys.append(scale_key) + + params = layout_cls.Params( + scale=scale if scale is not None else torch.ones((), dtype=torch.float32), + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.num_embeddings, self.embedding_dim), + ) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), + requires_grad=False) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) + else: + if layer_conf is not None: + state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + if destination is not None: + sd = destination + else: + sd = {} + + if not hasattr(self, 'weight') or self.weight is None: + return sd + + if isinstance(self.weight, QuantizedTensor): + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] + + quant_conf = {"format": self.quant_format} + sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + else: + sd["{}weight".format(prefix)] = self.weight + return sd + + def forward_comfy_cast_weights(self, input, out_dtype=None): + weight = self.weight + + # Optimized path: lookup in fp8, dequantize only the selected rows. + if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0: + qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True) + if isinstance(qdata, QuantizedTensor): + scale = qdata._params.scale + qdata = qdata._qdata + else: + scale = None + + x = torch.nn.functional.embedding( + input, qdata, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + uncast_bias_weight(self, qdata, None, offload_stream) + target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype + x = x.to(dtype=target_dtype) + if scale is not None and scale != 1.0: + x = x * scale.to(dtype=target_dtype) + return x + + # Fallback for non-quantized or weight_function (LoRA) case + return super().forward_comfy_cast_weights(input, out_dtype=out_dtype) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 42ee08fb2..b90bcfd25 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,6 +1,8 @@ import torch import logging +from comfy.cli_args import args + try: import comfy_kitchen as ck from comfy_kitchen.tensor import ( @@ -21,7 +23,15 @@ try: ck.registry.disable("cuda") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") - ck.registry.disable("triton") + if args.enable_triton_backend: + try: + import triton + logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__) + except ImportError as e: + logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.") + ck.registry.disable("triton") + else: + ck.registry.disable("triton") for k, v in ck.list_backends().items(): logging.info(f"Found comfy_kitchen backend {k}: {v}") except ImportError as e: diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index ab7cf14fa..e54be98d6 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -3,6 +3,7 @@ import comfy.model_management RMSNorm = torch.nn.RMSNorm +# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding). def rms_norm(x, weight=None, eps=1e-6): if weight is None: return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index bbba09e26..3782fd2d5 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -89,7 +89,8 @@ def get_additional_models(conds, dtype): gligen += get_models_from_cond(conds[k], "gligen") add_models += get_models_from_cond(conds[k], "additional_models") - control_nets = set(cnets) + # Order-preserving dedup. A plain set() would randomize iteration order across runs + control_nets = list(dict.fromkeys(cnets)) inference_memory = 0 control_models = [] diff --git a/comfy/sd.py b/comfy/sd.py index ee66490f5..9fce0e7d0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -65,6 +65,7 @@ import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie +import comfy.text_encoders.gemma4 import comfy.model_patcher import comfy.lora @@ -1271,6 +1272,9 @@ class TEModel(Enum): QWEN35_9B = 26 QWEN35_27B = 27 MINISTRAL_3_3B = 28 + GEMMA_4_E4B = 29 + GEMMA_4_E2B = 30 + GEMMA_4_31B = 31 def detect_te_model(sd): @@ -1296,6 +1300,12 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.59.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_4_31B + if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd: + return TEModel.GEMMA_4_E4B + if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd: + return TEModel.GEMMA_4_E2B if 'model.layers.47.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: @@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer + elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): + variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, + TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, + TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model] + clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant) + clip_target.tokenizer = variant.tokenizer + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e6c17fb98..dff40461f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1167,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) +class WAN21_CausalAR_T2V(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "causal_ar": True, + } + + sampling_settings = { + "shift": 5.0, + } + + def __init__(self, unet_config): + super().__init__(unet_config) + self.unet_config.pop("causal_ar", None) + + def get_model(self, state_dict, prefix="", device=None): + return model_base.WAN21_CausalAR(self, device=device) + + class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1929,6 +1948,7 @@ models = [ ZImage, Lumina2, WAN22_T2V, + WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py new file mode 100644 index 000000000..f050061ed --- /dev/null +++ b/comfy/text_encoders/gemma4.py @@ -0,0 +1,1298 @@ +import torch +import torch.nn as nn +import numpy as np +from dataclasses import dataclass +import math + +from comfy import sd1_clip +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.rmsnorm import rms_norm +from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding + + +# Intentional minor divergences from transformers -reference implementation: +# - Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor. +# - RMSNorm uses torch fused F.rms_norm, very slight numerical differences, but considerably faster +# - Input image and audio resizing/resampling slightly different numerically + + +GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} +GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} + +@dataclass +class Gemma4Config: + vocab_size: int = 262144 + hidden_size: int = 2560 + intermediate_size: int = 10240 + num_hidden_layers: int = 42 + num_attention_heads: int = 8 + num_key_value_heads: int = 2 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma4" + head_dim = 256 + global_head_dim = 512 + rms_norm_add = False + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [512, 512, 512, 512, 512, False] + rope_scale = None + partial_rotary_factor: float = 0.25 + final_norm: bool = True + lm_head: bool = False + final_logit_softcapping: float = 30.0 + hidden_size_per_layer_input: int = 256 + num_kv_shared_layers: int = 18 + use_double_wide_mlp: bool = False + stop_tokens = [1, 50, 106] + vision_config = GEMMA4_VISION_CONFIG + audio_config = GEMMA4_AUDIO_CONFIG + mm_tokens_per_image = 280 + +@dataclass +class Gemma4_E2B_Config(Gemma4Config): + hidden_size: int = 1536 + intermediate_size: int = 6144 + num_hidden_layers: int = 35 + num_key_value_heads: int = 1 + sliding_attention = [512, 512, 512, 512, False] + num_kv_shared_layers: int = 20 + use_double_wide_mlp: bool = True + +@dataclass +class Gemma4_31B_Config(Gemma4Config): + hidden_size: int = 5376 + intermediate_size: int = 21504 + num_hidden_layers: int = 60 + num_attention_heads: int = 32 + num_key_value_heads: int = 16 + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + hidden_size_per_layer_input: int = 0 + num_kv_shared_layers: int = 0 + audio_config = None + vision_config = GEMMA4_VISION_31B_CONFIG + + +# unfused RoPE as addcmul_ RoPE diverges from reference code +def _apply_rotary_pos_emb(x, freqs_cis): + cos, sin = freqs_cis[0], freqs_cis[1] + half = x.shape[-1] // 2 + out = x * cos + out[..., :half] -= x[..., half:] * sin[..., :half] + out[..., half:] += x[..., :half] * sin[..., half:] + return out + +class Gemma4Attention(nn.Module): + def __init__(self, config, head_dim, device=None, dtype=None, ops=None): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.hidden_size = config.hidden_size + self.head_dim = head_dim + self.inner_size = self.num_heads * head_dim + + self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype) + self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) + self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) + self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) + + self.q_norm = None + self.k_norm = None + if config.q_norm == "gemma3": + self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) + if config.k_norm == "gemma3": + self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + freqs_cis=None, + past_key_value=None, + sliding_window=None, + shared_kv=None, + ): + batch_size, seq_length, _ = hidden_states.shape + + xq = self.q_proj(hidden_states) + xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) + if self.q_norm is not None: + xq = self.q_norm(xq) + + if shared_kv is not None: + xk, xv = shared_kv + # Apply RoPE to Q only (K already has RoPE from source layer) + xq = _apply_rotary_pos_emb(xq, freqs_cis) + present_key_value = None + shareable_kv = None + else: + xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + if self.k_norm is not None: + xk = self.k_norm(xk) + xv = rms_norm(xv) + xk = xk.transpose(1, 2) + xv = xv.transpose(1, 2) + xq = _apply_rotary_pos_emb(xq, freqs_cis) + xk = _apply_rotary_pos_emb(xk, freqs_cis) + + present_key_value = None + if past_key_value is not None: + cumulative_len = 0 + if len(past_key_value) > 0: + past_key, past_value, cumulative_len = past_key_value + xk = torch.cat((past_key, xk), dim=2) + xv = torch.cat((past_value, xv), dim=2) + new_cumulative = cumulative_len + seq_length + if sliding_window is not None and xk.shape[2] > sliding_window - 1: + cache_k = xk[:, :, -(sliding_window - 1):] + cache_v = xv[:, :, -(sliding_window - 1):] + else: + cache_k = xk + cache_v = xv + present_key_value = (cache_k, cache_v, new_cumulative) + + # KV for sharing: full xk/xv that SDPA sees (not evicted cache) + shareable_kv = (xk, xv) + + # GQA: pass unexpanded KV with enable_gqa when no sliding mask, + # expand heads when sliding mask is present + # has to be done within SDPA itself to match the reference code, pre-scaling expansion causes numerical differences + expand_kv = (self.num_heads != self.num_kv_heads and + sliding_window is not None and + xk.shape[2] >= sliding_window) + if expand_kv: + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + gqa_kwargs = {} if expand_kv else ({"enable_gqa": True} if self.num_heads != self.num_kv_heads else {}) + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0, **gqa_kwargs) + + return self.o_proj(output), present_key_value, shareable_kv + + +class TransformerBlockGemma4(nn.Module): + def __init__(self, config, index, device=None, dtype=None, ops=None): + super().__init__() + if config.sliding_attention is not None: + self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] + else: + self.sliding_attention = False + + head_dim = config.head_dim if self.sliding_attention else config.global_head_dim + + self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops) + + num_kv_shared = config.num_kv_shared_layers + first_kv_shared = config.num_hidden_layers - num_kv_shared + mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) + self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) + self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) + else: + self.layer_scalar = None + + def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None): + sliding_window = None + if self.sliding_attention: + sliding_window = self.sliding_attention + # For prefill > sliding window, add sliding window restriction to the causal mask. + if x.shape[1] > self.sliding_attention: + sw_mask = torch.zeros(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) + sw_mask.masked_fill_(torch.ones_like(sw_mask, dtype=torch.bool).tril_(-self.sliding_attention), torch.finfo(x.dtype).min) + attention_mask = attention_mask + sw_mask if attention_mask is not None else sw_mask + freqs_cis = freqs_cis[1] + else: + freqs_cis = freqs_cis[0] + + residual = x + x = self.input_layernorm(x) + x, present_key_value, shareable_kv = self.self_attn( + hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, + past_key_value=past_key_value, sliding_window=sliding_window, shared_kv=shared_kv, + ) + x = self.post_attention_layernorm(x) + x = residual + x + + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + + if self.hidden_size_per_layer_input and per_layer_input is not None: + residual = x + x = self.per_layer_input_gate(x) + x = torch.nn.functional.gelu(x, approximate="tanh") + x = x * per_layer_input + x = self.per_layer_projection(x) + x = self.post_per_layer_input_norm(x) + x = residual + x + + if self.layer_scalar is not None: + x = x * self.layer_scalar + + return x, present_key_value, shareable_kv + + +class Gemma4Transformer(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.config = config + + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) + + self.layers = nn.ModuleList([ + TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ]) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None + + # Precompute RoPE inv_freq on CPU to match reference code's exact value + rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2) + nope_global = config.global_head_dim // 2 - rope_angles_global + global_inv = 1.0 / (config.rope_theta[0] ** (torch.arange(0, 2 * rope_angles_global, 2).float() / config.global_head_dim)) + if nope_global > 0: + global_inv = torch.cat([global_inv, torch.zeros(nope_global)]) + self.register_buffer("_global_inv_freq", global_inv, persistent=False) + + sliding_inv = 1.0 / (config.rope_theta[1] ** (torch.arange(0, config.head_dim, 2).float() / config.head_dim)) + self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False) + + # Per-layer input mechanism + self.hidden_size_per_layer_input = config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype) + self.per_layer_model_projection = ops.Linear( + config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input, + bias=False, device=device, dtype=dtype) + self.per_layer_projection_norm = RMSNorm( + self.hidden_size_per_layer_input, eps=config.rms_norm_eps, + device=device, dtype=dtype) + + def get_past_len(self, past_key_values): + for kv in past_key_values: + if len(kv) >= 3: + return kv[2] + return 0 + + def _freqs_from_inv(self, inv_freq, position_ids, device, dtype): + """Compute cos/sin from stored inv_freq""" + inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device) + pos_exp = position_ids[:, None, :].float() + freqs = (inv_exp @ pos_exp).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype) + + def compute_freqs_cis(self, position_ids, device, dtype=None): + global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype) + sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype) + return [global_freqs, sliding_freqs] + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, + final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=None, + past_key_values=None, input_ids=None): + if embeds is not None: + x = embeds + else: + x = self.embed_tokens(x, out_dtype=dtype) + + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = self.get_past_len(past_key_values) + + if position_ids is None: + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) + + freqs_cis = self.compute_freqs_cis(position_ids, x.device, dtype=x.dtype) + + mask = None + min_val = torch.finfo(x.dtype).min + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), min_val) + + if seq_len > 1: + causal_mask = torch.zeros(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device) + causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val) + mask = mask + causal_mask if mask is not None else causal_mask + + # Per-layer inputs + per_layer_inputs = None + if self.hidden_size_per_layer_input: + num_layers = self.config.num_hidden_layers + hpl = self.hidden_size_per_layer_input + per_layer_proj = self.per_layer_model_projection(x) * (1.0 / (self.config.hidden_size ** 0.5)) + per_layer_proj = self.per_layer_projection_norm(per_layer_proj.reshape(*x.shape[:-1], num_layers, hpl)) + if input_ids is not None and input_ids.shape[1] == x.shape[1]: + per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl) + per_layer_inputs = (per_layer_proj + per_layer_emb) * (0.5 ** 0.5) + else: + per_layer_inputs = per_layer_proj + + # KV sharing: later layers reuse KV from the last non-shared sliding/global layer + num_kv_shared = self.config.num_kv_shared_layers + first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers + shared_sliding_kv = None # KV from last non-shared sliding layer + shared_global_kv = None # KV from last non-shared global layer + + intermediate = None + next_key_values = [] + for i, layer in enumerate(self.layers): + past_kv = past_key_values[i] if past_key_values is not None and len(past_key_values) > 0 else None + + layer_kwargs = {} + if per_layer_inputs is not None: + layer_kwargs['per_layer_input'] = per_layer_inputs[:, :, i, :] + + is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention + if i >= first_kv_shared and num_kv_shared > 0: + shared = shared_sliding_kv if is_sliding else shared_global_kv + if shared is not None: + layer_kwargs['shared_kv'] = shared + + x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis, past_key_value=past_kv, **layer_kwargs) + + next_key_values.append(current_kv if current_kv is not None else ()) + + # Only track the last sliding/global before the sharing boundary + if i < first_kv_shared and shareable_kv is not None: + if is_sliding: + shared_sliding_kv = shareable_kv + else: + shared_global_kv = shareable_kv + + if i == intermediate_output: + intermediate = x.clone() + + if self.norm is not None: + x = self.norm(x) + + if len(next_key_values) > 0: + return x, intermediate, next_key_values + return x, intermediate + + +class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module): + """Common base for all Gemma4 variants: text model + vision.""" + def _init_model(self, config, dtype, device, operations): + self.num_layers = config.num_hidden_layers + self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations) + self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations) + + def logits(self, x): + logits = super().logits(x) + cap = self.model.config.final_logit_softcapping + if cap: + logits = cap * torch.tanh(logits / cap) + return logits + + def init_kv_cache(self, batch, max_cache_len, device, execution_dtype): + past_key_values = [] + for _ in range(self.model.config.num_hidden_layers): + past_key_values.append(()) + return past_key_values + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W] + max_soft_tokens = embed.get("max_soft_tokens", None) + vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens) + return self.multi_modal_projector(vision_out), None + return None, None + + +class Gemma4AudioMixin: + """Adds audio support to a Gemma4 model.""" + def _init_audio(self, config, dtype, device, operations): + self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations) + self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations) + + def preprocess_embed(self, embed, device): + result, extra = super().preprocess_embed(embed, device) + if result is not None: + return result, extra + if embed["type"] == "audio": + audio = embed.pop("data").to(device, dtype=torch.float32) + audio_mask = embed.pop("mask", None) + if audio_mask is not None: + audio_mask = audio_mask.to(device) + audio_out = self.audio_model(audio, audio_mask=audio_mask) + return self.audio_projector(audio_out), None + return None, None + + +# Vision Encoder + +def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): + """Compute 2D RoPE for vision: separate frequencies for x and y dimensions. + + Args: + head_dim: dimension per head (e.g. 64) + pixel_position_ids: [batch, num_patches, 2] with (x, y) coords + theta: RoPE base frequency + Returns: + (cos, sin) each of shape [batch, num_patches, head_dim] + """ + rotary_dim_per_axis = head_dim // 2 + freq_indices = torch.arange(0, rotary_dim_per_axis, 2, device=device).float() + inv_freq = 1.0 / (theta ** (freq_indices / rotary_dim_per_axis)) + + all_cos, all_sin = [], [] + for i in range(2): # x and y + dim_positions = pixel_position_ids[:, :, i].float() # [batch, num_patches] + freqs = torch.einsum('bi,j->bij', dim_positions, inv_freq.to(device)) # [batch, num_patches, rotary_dim/2] + emb = torch.cat([freqs, freqs], dim=-1) # [batch, num_patches, rotary_dim] + all_cos.append(emb.cos()) + all_sin.append(emb.sin()) + + cos = torch.cat(all_cos, dim=-1).to(pixel_position_ids.device) # [batch, num_patches, head_dim] + sin = torch.cat(all_sin, dim=-1).to(pixel_position_ids.device) + return cos, sin + + +def _apply_vision_2d_rope(x, freqs): + """Apply 2D RoPE (multidimensional) to vision query/key states. + + Splits x and cos/sin into ndim=2 parts, applies 1D RoPE to each independently. + + x: [batch, heads, seq, head_dim] + freqs: (cos, sin) each [batch, seq, head_dim] + """ + cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim] + sin = freqs[1].unsqueeze(1) + half = x.shape[-1] // 2 + a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half])) + b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:])) + return torch.cat([a, b], dim=-1) + + +class ClippedLinear(nn.Module): + """Linear layer with activation clipping (from quantization-aware training). + + Stores input_max/min and output_max/min as buffers loaded from checkpoint. + """ + def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None): + super().__init__() + self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype)) + self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) + self.register_buffer('output_max', torch.tensor(float('inf'), device=device, dtype=dtype)) + self.register_buffer('output_min', torch.tensor(float('-inf'), device=device, dtype=dtype)) + + @property + def weight(self): + return self.linear.weight + + def forward(self, x): + x = x.clamp(min=self.input_min, max=self.input_max) + x = self.linear(x) + return x.clamp_(min=self.output_min, max=self.output_max) + + +class Gemma4VisionMLP(nn.Module): + """SwiGLU MLP matching gate_proj/up_proj/down_proj structure.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + intermediate_size = config["intermediate_size"] + self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) + + def forward(self, x): + return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)) + + +class Gemma4VisionAttention(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads) + + self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops) + self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops) + + self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype) + + def forward(self, x, freqs, attention_mask=None): + batch_size, seq_length, _ = x.shape + + xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim) + + xq = self.q_norm(xq).transpose(1, 2) + xk = self.k_norm(xk).transpose(1, 2) + xv = rms_norm(xv) + + xq = _apply_vision_2d_rope(xq, freqs) + xk = _apply_vision_2d_rope(xk, freqs) + + xv = xv.to(xq.dtype).transpose(1, 2) + + output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0) + return self.o_proj(output) + + +class Gemma4VisionLayer(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) + hidden = config["hidden_size"] + self.input_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs) + self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) + self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs) + + def forward(self, x, freqs, attention_mask=None): + residual = x + x = self.input_layernorm(x) + x = self.self_attn(x, freqs, attention_mask=attention_mask) + x = self.post_attention_layernorm(x) + x = residual + x + + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + return x + + +class Gemma4PatchEmbedder(nn.Module): + """Patch embedding with learned 2D position embeddings via one-hot lookup.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + patch_size = config["patch_size"] + self.patch_size = patch_size + self.position_embedding_size = config.get("position_embedding_size", 10240) + + self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype) + self.position_embedding_table = nn.Parameter( + torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype) + ) + + def forward(self, patches, pixel_position_ids): + """ + patches: [B, num_patches, 3*patch_size²] in [0,1] range (normalized to [-1,1] inside, matching HF) + pixel_position_ids: [B, num_patches, 2] with (x,y) positions, (-1,-1) for padding + """ + hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype)) + + clamped_positions = pixel_position_ids.clamp(min=0) + pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype) + position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]] + + # Zero out position embeddings for padding patches (matching HF) + padding_positions = (pixel_position_ids == -1).all(dim=-1) + position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings) + + return hidden_states + position_embeddings + + +class Gemma4VisionEncoderLayers(nn.Module): + """Wrapper to produce state dict keys as encoder.layers.X.*""" + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([ + Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops) + for _ in range(config["num_hidden_layers"]) + ]) + + +class Gemma4VisionEncoder(nn.Module): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.config = config + self.hidden_size = config["hidden_size"] + self.head_dim = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]) + self.patch_size = config["patch_size"] + self.pooling_kernel_size = config.get("pooling_kernel_size", 3) + self.root_hidden_size = self.hidden_size ** 0.5 + + self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops) + self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops) + + def forward(self, pixel_values, max_soft_tokens=None): + """ + pixel_values: [B, C, H, W] in [0,1] range + max_soft_tokens: if provided, pad to max_soft_tokens * k² total patches + """ + batch_size, _, height, width = pixel_values.shape + ps = self.patch_size + k = self.pooling_kernel_size + patches_h, patches_w = height // ps, width // ps + num_patches = patches_h * patches_w + output_length = max_soft_tokens if max_soft_tokens is not None else num_patches // (k * k) + n_padding = output_length * k * k - num_patches + + # Patchify and build position grid + patches = pixel_values.reshape(batch_size, -1, patches_h, ps, patches_w, ps) + patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, num_patches, -1) + grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij') + position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1) + + # Append zero-pixel padding with (-1,-1) positions + if n_padding > 0: + patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1) + position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1) + + padding = (position_ids == -1).all(dim=-1) + + # Embed, encode, pool + x = self.patch_embedder(patches, position_ids) + freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device) + freqs = tuple(t.to(x.dtype) for t in freqs) + if n_padding > 0: + mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1) + mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min) + else: + mask = None + + for layer in self.encoder.layers: + x = layer(x, freqs, attention_mask=mask) + + if n_padding > 0: + x = x.masked_fill(padding.unsqueeze(-1), 0.0) + + # Average pool by spatial position + clamped = position_ids.clamp(min=0) + max_x = clamped[:, :, 0].max(dim=-1, keepdim=True)[0] + 1 + ki = torch.div(clamped, k, rounding_mode="floor") + ki = ki[:, :, 0] + (max_x // k) * ki[:, :, 1] + weights = torch.nn.functional.one_hot(ki.long(), output_length).float() / (k * k) + x = (weights.transpose(1, 2) @ x.float()).to(x.dtype) + + # Strip empty output tokens + valid_out = ~((weights == 0).all(dim=1)) + if valid_out.any() and not valid_out.all(): + x = x[:, valid_out[0]] if batch_size > 1 else x[valid_out].unsqueeze(0) + + return x * self.root_hidden_size + + +class Gemma4RMSNormProjector(nn.Module): + """Shared projector: parameterless RMSNorm → linear. Used for both vision and audio.""" + def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None): + super().__init__() + self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.embedding_projection(rms_norm(x)) + + +class Gemma4MultiModalProjector(Gemma4RMSNormProjector): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops) + + +# Audio Encoder + +class Gemma4AudioConvSubsampler(nn.Module): + """2D convolution subsampling for audio features""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + eps = config["rms_norm_eps"] + self.layer0 = nn.ModuleDict({ + 'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + self.layer1 = nn.ModuleDict({ + 'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype), + 'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype), + }) + # proj_input_dim = (128 // 4) * 32 = 1024 + self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype) + + def _conv_layer(self, x, layer, mask): + if mask is not None: + x = x * mask[:, None, :, None].to(x.device) + x = layer['conv'](x.to(layer['conv'].weight.dtype)) + x = torch.relu(layer['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous()) + if mask is not None: + mask = mask[:, ::2] + return x, mask + + def forward(self, x, mask=None): + x = x.unsqueeze(1) + x, mask = self._conv_layer(x, self.layer0, mask) + x, mask = self._conv_layer(x, self.layer1, mask) + batch_size, _, seq_len, _ = x.shape + x = x.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1) + return self.input_proj_linear(x), mask + + +class Gemma4AudioFeedForward(nn.Module): + """Conformer feed-forward with residual scaling.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + intermediate_size = config.get("intermediate_size", hidden_size * 4) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops) + self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.post_layer_scale = config.get("residual_weight", 0.5) + + def forward(self, x): + residual = x + x = self.pre_layer_norm(x) + x = torch.nn.functional.silu(self.ffw_layer_1(x)) + x = self.ffw_layer_2(x) + x = self.post_layer_norm(x) + x = x * self.post_layer_scale + return x + residual + + +class Gemma4AudioRelPositionalEncoding(nn.Module): + """Sinusoidal relative positional encoding for audio attention.""" + def __init__(self, config, device=None, dtype=None): + super().__init__() + hidden_size = config["hidden_size"] + context_left = config.get("attention_context_left", 13) + context_right = config.get("attention_context_right", 0) + self.chunk_size = config.get("attention_chunk_size", 12) + self.context_size = self.chunk_size + context_left - 1 + context_right + + num_timescales = hidden_size // 2 + log_inc = math.log(10000.0) / max(num_timescales - 1, 1) + inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).to(dtype=dtype).unsqueeze(0).unsqueeze(0) + self.register_buffer("inv_timescales", inv_timescales, persistent=False) + + def forward(self, hidden_states): + positions = torch.arange(self.chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1) + scaled = positions * self.inv_timescales.to(device=hidden_states.device) + return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).to(dtype=hidden_states.dtype) + + +class Gemma4AudioAttention(nn.Module): + """Chunked block attention with relative position bias and softcap.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.num_heads = config["num_attention_heads"] + self.head_dim = self.hidden_size // self.num_heads + self.chunk_size = config.get("attention_chunk_size", 12) + self.max_past_horizon = config.get("attention_context_left", 13) - 1 + self.max_future_horizon = config.get("attention_context_right", 0) + self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon + + self.q_scale = (self.head_dim ** -0.5) / math.log(2) + self.k_scale = math.log(1 + math.e) / math.log(2) + self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False) + + self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops) + self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype)) + self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype) + + def _convert_to_block(self, x): + B, S, H, D = x.shape + num_blocks = (S + self.chunk_size - 1) // self.chunk_size + pad = num_blocks * self.chunk_size - S + x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad)) + return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous() + + def _extract_block_context(self, x): + x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1)) + x = x.unfold(1, self.context_size, self.chunk_size) + return torch.movedim(x, -1, 2).contiguous() + + def _rel_shift(self, x): + B, H, NB, BS, PL = x.shape + CS = self.context_size + x = torch.nn.functional.pad(x, (0, CS + 1 - PL)) + x = x.view(B, H, NB, BS * (CS + 1)) + x = x[..., :BS * CS] + return x.view(B, H, NB, BS, CS) + + def _build_blocked_mask(self, seq_len, num_blocks, device, audio_mask=None): + """Build 5D boolean blocked attention mask (True=attend, False=mask)""" + q = torch.arange(seq_len, device=device) + dist = q[:, None] - q[None, :] + mask = (dist >= 0) & (dist < self.max_past_horizon) + if self.max_future_horizon > 0: + mask = mask | ((dist < 0) & ((-dist) < self.max_future_horizon)) + if audio_mask is not None: + mask = mask & audio_mask[0, None, :].bool() + m = mask[None, None] + # Reshape to blocked 5D matching reference code + p = num_blocks * self.chunk_size - seq_len + m = torch.nn.functional.pad(m, (0, p, 0, p), value=False) + m = m.reshape(1, 1, num_blocks, self.chunk_size, -1) + m = torch.nn.functional.pad(m, (self.max_past_horizon, self.max_future_horizon), value=False) + idx = (torch.arange(num_blocks, device=device) * self.chunk_size)[:, None] + torch.arange(self.context_size, device=device)[None, :] + return m.gather(-1, idx[None, None, :, None, :].expand(1, 1, -1, self.chunk_size, -1)) + + def forward(self, x, position_embeddings=None, attn_mask=None): + B, S, _ = x.shape + + q = self.q_proj(x).float().view(B, S, self.num_heads, self.head_dim) + k = self.k_proj(x).float().view(B, S, self.num_heads, self.head_dim) + v = self.v_proj(x).float().view(B, S, self.num_heads, self.head_dim) + + q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale) + k = k * self.k_scale + + q_blocks = self._convert_to_block(q) + k_context = self._extract_block_context(k) + v_context = self._extract_block_context(v) + num_blocks = q_blocks.shape[1] + + rel_k = self.relative_k_proj(position_embeddings).view(-1, self.num_heads, self.head_dim).to(q.dtype) + + queries = q_blocks.permute(0, 3, 1, 2, 4) # [B, H, NB, CS, D] + matrix_ac = queries @ k_context.permute(0, 3, 1, 4, 2) + + queries_flat = queries.reshape(B, self.num_heads, -1, self.head_dim) + matrix_bd = queries_flat @ rel_k.permute(1, 2, 0) + matrix_bd = matrix_bd.reshape(B, self.num_heads, num_blocks, self.chunk_size, -1) + matrix_bd = self._rel_shift(matrix_bd) + + attn_weights = matrix_ac + matrix_bd + attn_weights = torch.tanh(attn_weights / self.softcap) * self.softcap + + # Mask out invalid positions in chunk context (matching reference's masked_fill approach) + if attn_mask is None: + attn_mask = self._build_blocked_mask(S, num_blocks, x.device) + attn_weights = attn_weights.masked_fill(attn_mask.logical_not(), -1e9) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype) + out = attn_weights @ v_context.permute(0, 3, 1, 2, 4) + out = out.permute(0, 2, 3, 1, 4).reshape(B, num_blocks * self.chunk_size, -1) + out = out[:, :S].contiguous() + return self.post(out.to(self.post.linear.weight.dtype)) + + +class Gemma4AudioLConv1d(nn.Module): + """Lightweight convolution with standard GLU.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + hidden_size = config["hidden_size"] + conv_kernel_size = config.get("conv_kernel_size", 5) + self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops) + # Causal conv: left-pad only + self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype) + self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1 + self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype) + self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops) + + def forward(self, x): + residual = x + x = self.pre_layer_norm(x) + x = self.linear_start(x) + x = torch.nn.functional.glu(x, dim=-1) + x = x.transpose(1, 2) + x = torch.nn.functional.pad(x, (self.conv_left_pad, 0)) + x = self.depthwise_conv1d(x).transpose(1, 2) + x = self.conv_norm(x) + x = torch.nn.functional.silu(x) + x = self.linear_end(x) + return x + residual + + +class Gemma4AudioLayer(nn.Module): + """Conformer block: FFN1 -> Attention -> LConv -> FFN2.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops) + norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype) + hidden_size = config["hidden_size"] + self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs) + self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs) + self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops) + self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops) + self.norm_out = RMSNorm(hidden_size, **norm_kwargs) + + def forward(self, x, position_embeddings=None, attn_mask=None): + x = self.feed_forward1(x) + + residual = x + x = self.norm_pre_attn(x) + x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask) + x = self.norm_post_attn(x) + x = x + residual + + x = self.lconv1d(x) + x = self.feed_forward2(x) + + x = self.norm_out(x) + return x + + +class Gemma4AudioEncoder(nn.Module): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__() + self.hidden_size = config["hidden_size"] + self.output_proj_dims = config.get("output_proj_dims", 1536) + + self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops) + self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype) + + self.layers = nn.ModuleList([ + Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops) + for _ in range(config["num_hidden_layers"]) + ]) + + self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype) + + def forward(self, audio_features, audio_mask=None): + x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask) + position_embeddings = self.rel_pos_enc(x) + + # Build blocked attention mask once for all layers + attn_mask = self.layers[0].self_attn._build_blocked_mask( + x.shape[1], (x.shape[1] + self.layers[0].self_attn.chunk_size - 1) // self.layers[0].self_attn.chunk_size, + x.device, audio_mask=audio_mask) + + for layer in self.layers: + x = layer(x, position_embeddings=position_embeddings, attn_mask=attn_mask) + + x = self.output_proj(x) + return x + + +class Gemma4AudioProjector(Gemma4RMSNormProjector): + def __init__(self, config, dtype=None, device=None, ops=None): + super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops) + + +# Tokenizer and Wrappers + +class Gemma4_Tokenizer(): + tokenizer_json_data = None + + def state_dict(self): + if self.tokenizer_json_data is not None: + return {"tokenizer_json": self.tokenizer_json_data} + return {} + + def _extract_mel_spectrogram(self, waveform, sample_rate): + """Extract 128-bin log mel spectrogram. + Uses numpy for FFT/matmul/log to produce bit-identical results with reference code. + """ + # Mix to mono first, then resample to 16kHz + if waveform.dim() > 1 and waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + if waveform.dim() == 1: + waveform = waveform.unsqueeze(0) + audio = waveform.squeeze(0).float().numpy() + if sample_rate != 16000: + # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match) + from scipy.signal import resample_poly, firwin + from math import gcd + g = gcd(sample_rate, 16000) + up, down = 16000 // g, sample_rate // g + L = max(up, down) + h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5)) + audio = resample_poly(audio, up, down, window=h).astype(np.float32) + n = len(audio) + + # Pad to multiple of 128, build sample-level mask + if n % 128 != 0: + audio = np.pad(audio, (0, 128 - n % 128)) + mask_raw = np.ones(len(audio), dtype=np.float32) + mask_raw[n:] = 0.0 + + # Semicausal padding: 160 zeros prepended + audio = np.pad(audio, (160, 0)) + mask_raw = np.pad(mask_raw, (160, 0)) + + # Extract 321-sample frames via stride tricks, drop last → 320 + nf = (len(audio) - 321) // 160 + 1 + strides = (audio.strides[0] * 160, audio.strides[0]) + frames = np.lib.stride_tricks.as_strided(audio, (nf, 321), strides)[..., :-1].copy() + + # Periodic Hann window, FFT magnitude, mel filterbank, log + window = (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(320) / 320)).astype(np.float32) + magnitude = np.abs(np.fft.rfft(frames * window, n=512, axis=-1)) + mel_fb = self._build_mel_filterbank() + log_mel = np.log(np.matmul(magnitude, mel_fb) + np.float64(0.001)).astype(np.float32) + + # Frame mask: valid when last sample in window is real audio + mask = mask_raw[np.arange(nf) * 160 + 320].astype(bool) + log_mel = log_mel * mask[:, None] + return torch.from_numpy(log_mel), torch.from_numpy(mask) # [T, 128], [T] + + @staticmethod + def _build_mel_filterbank(): + """Build 128-bin HTK mel filterbank [257, 128] for 512-pt FFT at 16kHz.""" + mel_freqs = np.linspace(0.0, 2595.0 * np.log10(1.0 + 8000.0 / 700.0), 130) + filter_freqs = 700.0 * (10.0 ** (mel_freqs / 2595.0) - 1.0) + fft_freqs = np.linspace(0, 16000 // 2, 257) + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes)) + + def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, video=None, llama_template=None, skip_template=True, thinking=False, **kwargs): + + # Process audio + audio_features = [] + if audio is not None: + waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio + sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000 + mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) + audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) + + # Process image/video frames + is_video = video is not None + source = video if is_video else image + images = [] + if source is not None: + samples = source.movedim(-1, 1) # [B, C, H, W] + num_frames = samples.shape[0] + + # Subsample video to 1fps + if is_video: + fps = kwargs.get("fps", 24) + step = max(1, round(fps)) + indices = list(range(0, num_frames, step)) + if len(indices) == 0: + indices = [0] + samples = samples[indices] + num_frames = len(indices) + + h, w = samples.shape[2], samples.shape[3] + patch_size = 16 + pooling_k = 3 + max_soft_tokens = 70 if is_video else 280 # video uses smaller token budget per frame + max_patches = max_soft_tokens * pooling_k * pooling_k + target_px = max_patches * patch_size * patch_size + factor = (target_px / (h * w)) ** 0.5 + side_mult = pooling_k * patch_size + target_h = max(int(factor * h // side_mult) * side_mult, side_mult) + target_w = max(int(factor * w // side_mult) * side_mult, side_mult) + + import torchvision.transforms.functional as TVF + for i in range(num_frames): + # rescaling to match reference code + s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 + if target_h != h or target_w != w: + s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True) + s = s.float() * (1.0 / 255.0) + images.append({"pixels": s.unsqueeze(0).movedim(1, -1)[:, :, :, :3], "max_soft_tokens": max_soft_tokens}) + + if text.startswith('<|turn>'): + skip_template = True + + if skip_template: + llama_text = text + else: + if llama_template is not None: + llama_text = llama_template.format(text) + else: + # Build template from modalities present + system = "<|turn>system\n<|think|>\n" if thinking else "" + media = "" + if len(images) > 0: + if is_video: + media += "\n\n" + for i in range(len(images)): + ts = f"{int(i // 60):02d}:{int(i % 60):02d}" + sep = "" if i == 0 else " " + media += f"{sep}{ts} <|image><|video|>" + media += "\n\n" + else: + media += "\n\n" + for i in range(len(images)): + if i > 0: + media += "\n\n\n\n" + media += "<|image><|image|>" + media += "\n\n" + if len(audio_features) > 0: + # Compute audio token count (always at 16kHz) + num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1] + _fl = 320 # int(round(16000 * 20.0 / 1000.0)) + _hl = 160 # int(round(16000 * 10.0 / 1000.0)) + _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1 + _t = _nmel + for _ in range(2): + _t = (_t + 2 - 3) // 2 + 1 + n_audio_tokens = min(_t, 750) + media += "<|audio>" + "<|audio|>" * n_audio_tokens + "" + llama_text = f"{system}<|turn>user\n{media}{text}\n<|turn>model\n" + + text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + + def _replace_placeholders(token_list, token_id, embeds): + """Replace first placeholder with embed dict, remove remaining consecutive ones.""" + embed_idx = 0 + i = 0 + while i < len(token_list): + if token_list[i][0] == token_id and embed_idx < len(embeds): + token_list[i] = (embeds[embed_idx],) + token_list[i][1:] + embed_idx += 1 + i += 1 + while i < len(token_list) and token_list[i][0] == token_id: + token_list.pop(i) + else: + i += 1 + + if len(images) > 0: + img_token_id = 258884 if is_video else 258880 + img_embeds = [{"type": "image", "data": img["pixels"], "max_soft_tokens": img["max_soft_tokens"]} for img in images] + for r in text_tokens: + _replace_placeholders(r, img_token_id, img_embeds) + + if len(audio_features) > 0: + aud_embeds = [{"type": "audio", "data": mel, "mask": mask} for mel, mask in audio_features] + for r in text_tokens: + _replace_placeholders(r, 258881, aud_embeds) + + return text_tokens + + +class _Gemma4Tokenizer: + """Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)""" + def __init__(self, tokenizer_json_bytes=None, **kwargs): + from tokenizers import Tokenizer + if isinstance(tokenizer_json_bytes, torch.Tensor): + tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) + self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) + + @classmethod + def from_pretrained(cls, tokenizer_data, **kwargs): + return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) + + def __call__(self, text): + return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids} + + def get_vocab(self): + return self.tokenizer.get_vocab() + + def convert_tokens_to_ids(self, tokens): + return [self.tokenizer.token_to_id(t) for t in tokens] + + def decode(self, ids, **kwargs): + return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False)) + + +# Tokenizer +class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer): + embedding_size = 2560 + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_json = tokenizer_data.get("tokenizer_json", None) + self.tokenizer_json_data = tokenizer_json + super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data) + + def decode(self, token_ids, **kwargs): + text = super().decode(token_ids, skip_special_tokens=False) + # Translate thinking channel markers to standard / tags + text = text.replace("<|channel>thought\n", "\n") + text = text.replace("", "") + # Strip remaining special tokens + text = text.replace("", "").replace("", "").strip() + return text + + +class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): + tokenizer_class = Gemma4SDTokenizer + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class) + + +# Model wrappers +class Gemma4Model(sd1_clip.SDClipModel): + model_class = None + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + self.dtypes = set() + self.dtypes.add(dtype) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=self.model_class, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def process_tokens(self, tokens, device): + embeds, _, _, _ = super().process_tokens(tokens, device) + return embeds + + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0): + if isinstance(tokens, dict): + tokens = next(iter(tokens.values())) + tokens_only = [[t[0] for t in b] for b in tokens] + embeds, _, _, embeds_info = sd1_clip.SDClipModel.process_tokens(self, tokens_only, self.execution_device) + seq_len = embeds.shape[1] + ids = [0] * seq_len + expanded_idx = 0 + embed_map = {info["index"]: info["size"] for info in embeds_info} + for t in tokens_only[0]: + if expanded_idx in embed_map: + expanded_idx += embed_map[expanded_idx] + elif isinstance(t, int): + if expanded_idx < seq_len: + ids[expanded_idx] = t + expanded_idx += 1 + else: + expanded_idx += 1 + initial_token_ids = [ids] + input_ids = torch.tensor(initial_token_ids, device=self.execution_device) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) + + +def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None): + clip_model = type('Gemma4Model_', (Gemma4Model,), {'model_class': model_class}) + class Gemma4TEModel_(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name="gemma4", clip_model=clip_model, model_options=model_options) + return Gemma4TEModel_ + + +# Variants + +def _make_variant(config_cls): + audio = config_cls.audio_config is not None + bases = (Gemma4AudioMixin, Gemma4Base) if audio else (Gemma4Base,) + class Variant(*bases): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self._init_model(config_cls(**config_dict), dtype, device, operations) + if audio: + self._init_audio(self.model.config, dtype, device, operations) + embedding_size = config_cls.hidden_size + if embedding_size != Gemma4SDTokenizer.embedding_size: + tok_cls = type('T', (Gemma4SDTokenizer,), {'embedding_size': embedding_size}) + class Tokenizer(Gemma4Tokenizer): + tokenizer_class = tok_cls + Variant.tokenizer = Tokenizer + else: + Variant.tokenizer = Gemma4Tokenizer + return Variant + +Gemma4_E4B = _make_variant(Gemma4Config) +Gemma4_E2B = _make_variant(Gemma4_E2B_Config) +Gemma4_31B = _make_variant(Gemma4_31B_Config) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 6ea8e36b1..a34c41144 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -521,7 +521,7 @@ class Attention(nn.Module): else: present_key_value = (xk, xv, index + num_tokens) - if sliding_window is not None and xk.shape[2] > sliding_window: + if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1: xk = xk[:, :, -sliding_window:] xv = xv[:, :, -sliding_window:] attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None @@ -533,12 +533,12 @@ class Attention(nn.Module): return self.o_proj(output), present_key_value class MLP(nn.Module): - def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): + def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None): super().__init__() - ops = ops or nn - self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) - self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) + intermediate_size = intermediate_size or config.intermediate_size + self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) if config.mlp_activation == "silu": self.activation = torch.nn.functional.silu elif config.mlp_activation == "gelu_pytorch_tanh": @@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module): return x, present_key_value +def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype): + class ScaledEmbedding(ops.Embedding): + def forward(self, input_ids, out_dtype=None): + return super().forward(input_ids, out_dtype=out_dtype) * scale + return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype) + + class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): super().__init__() self.config = config self.vocab_size = config.vocab_size - self.embed_tokens = ops.Embedding( - config.vocab_size, - config.hidden_size, - device=device, - dtype=dtype - ) if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": transformer = TransformerBlockGemma2 - self.normalize_in = True + self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) else: transformer = TransformerBlock - self.normalize_in = False + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ transformer(config, index=i, device=device, dtype=dtype, ops=ops) @@ -690,15 +691,12 @@ class Llama2_(nn.Module): self.config.rope_dims, device=device) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None): if embeds is not None: x = embeds else: x = self.embed_tokens(x, out_dtype=dtype) - if self.normalize_in: - x *= self.config.hidden_size ** 0.5 - seq_len = x.shape[1] past_len = 0 if past_key_values is not None and len(past_key_values) > 0: @@ -850,7 +848,7 @@ class BaseGenerate: torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) return past_key_values - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None): device = embeds.device if stop_tokens is None: @@ -875,14 +873,16 @@ class BaseGenerate: pbar = comfy.utils.ProgressBar(max_length) # Generation loop + current_input_ids = initial_input_ids for step in tqdm(range(max_length), desc="Generating tokens"): - x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids) logits = self.logits(x)[:, -1] next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) token_id = next_token[0].item() generated_token_ids.append(token_id) embeds = self.model.embed_tokens(next_token).to(execution_dtype) + current_input_ids = next_token if initial_input_ids is not None else None pbar.update(1) if token_id in stop_tokens: diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 5aee1f4c0..bc5cbae28 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty): tokens_only = [[t[0] for t in b] for b in tokens] - embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is class DualLinearProjection(torch.nn.Module): diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index 01ebdfabe..b1f1dbb9f 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) def process_tokens(self, tokens, device): - embeds, _, _, embeds_info = super().process_tokens(tokens, device) - comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) + embeds, _, _, _ = super().process_tokens(tokens, device) return embeds class LuminaModel(sd1_clip.SD1ClipModel): diff --git a/comfy/text_encoders/qwen35.py b/comfy/text_encoders/qwen35.py index ce9b07464..d8ed9cd32 100644 --- a/comfy/text_encoders/qwen35.py +++ b/comfy/text_encoders/qwen35.py @@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_): nn.Module.__init__(self) self.config = config self.vocab_size = config.vocab_size - self.normalize_in = False - self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.layers = nn.ModuleList([ Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops) diff --git a/comfy/utils.py b/comfy/utils.py index 78c491b98..7b7faad3a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res -def normalize_image_embeddings(embeds, embeds_info, scale_factor): - """Normalize image embeddings to match text embedding scale""" - for info in embeds_info: - if info.get("type") == "image": - start_idx = info["index"] - end_idx = start_idx + info["size"] - embeds[:, start_idx:end_idx, :] /= scale_factor diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index 9f6918315..adb5a3144 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -5,12 +5,95 @@ This module handles capability negotiation between frontend and backend, allowing graceful protocol evolution while maintaining backward compatibility. """ -from typing import Any +import logging +from typing import Any, TypedDict from comfy.cli_args import args + +class FeatureFlagInfo(TypedDict): + type: str + default: Any + description: str + + +# Registry of known CLI-settable feature flags. +# Launchers can query this via --list-feature-flags to discover valid flags. +CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = { + "show_signin_button": { + "type": "bool", + "default": False, + "description": "Show the sign-in button in the frontend even when not signed in", + }, +} + + +def _coerce_bool(v: str) -> bool: + """Strict bool coercion: only 'true'/'false' (case-insensitive). + + Anything else raises ValueError so the caller can warn and drop the flag, + rather than silently treating typos like 'ture' or 'yes' as False. + """ + lower = v.lower() + if lower == "true": + return True + if lower == "false": + return False + raise ValueError(f"expected 'true' or 'false', got {v!r}") + + +_COERCE_FNS: dict[str, Any] = { + "bool": _coerce_bool, + "int": lambda v: int(v), + "float": lambda v: float(v), +} + + +def _coerce_flag_value(key: str, raw_value: str) -> Any: + """Coerce a raw string value using the registry type, or keep as string. + + Returns the raw string if the key is unregistered or the type is unknown. + Raises ValueError/TypeError if the key is registered with a known type but + the value cannot be coerced; callers are expected to warn and drop the flag. + """ + info = CLI_FEATURE_FLAG_REGISTRY.get(key) + if info is None: + return raw_value + coerce = _COERCE_FNS.get(info["type"]) + if coerce is None: + return raw_value + return coerce(raw_value) + + +def _parse_cli_feature_flags() -> dict[str, Any]: + """Parse --feature-flag key=value pairs from CLI args into a dict. + + Items without '=' default to the value 'true' (bare flag form). + Flags whose value cannot be coerced to the registered type are dropped + with a warning, so a typo like '--feature-flag some_bool=ture' does not + silently take effect as the wrong value. + """ + result: dict[str, Any] = {} + for item in getattr(args, "feature_flag", []): + key, sep, raw_value = item.partition("=") + key = key.strip() + if not key: + continue + if not sep: + raw_value = "true" + try: + result[key] = _coerce_flag_value(key, raw_value.strip()) + except (ValueError, TypeError) as e: + info = CLI_FEATURE_FLAG_REGISTRY.get(key, {}) + logging.warning( + "Could not coerce --feature-flag %s=%r to %s (%s); dropping flag.", + key, raw_value.strip(), info.get("type", "?"), e, + ) + return result + + # Default server capabilities -SERVER_FEATURE_FLAGS: dict[str, Any] = { +_CORE_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, @@ -18,6 +101,11 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "assets": args.enable_assets, } +# CLI-provided flags cannot overwrite core flags +_cli_flags = {k: v for k, v in _parse_cli_feature_flags().items() if k not in _CORE_FEATURE_FLAGS} + +SERVER_FEATURE_FLAGS: dict[str, Any] = {**_CORE_FEATURE_FLAGS, **_cli_flags} + def get_connection_feature( sockets_metadata: dict[str, dict[str, Any]], diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4942ed46c..e50266bc5 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -395,7 +395,6 @@ class Combo(ComfyTypeIO): @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' - # TODO: something is wrong with the serialization, frontend does not recognize it as multiselect Type = list[str] class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, @@ -408,12 +407,14 @@ class MultiCombo(ComfyTypeI): self.default: list[str] def as_dict(self): - to_return = super().as_dict() | prune_dict({ - "multi_select": self.multiselect, - "placeholder": self.placeholder, - "chip": self.chip, + # Frontend expects `multi_select` to be an object config (not a boolean). + # Keep top-level `multiselect` from Combo.Input for backwards compatibility. + return super().as_dict() | prune_dict({ + "multi_select": prune_dict({ + "placeholder": self.placeholder, + "chip": self.chip, + }), }) - return to_return @comfytype(io_type="IMAGE") class Image(ComfyTypeIO): diff --git a/comfy_api_nodes/apis/luma.py b/comfy_api_nodes/apis/luma.py index 632c4ab96..8c6db2022 100644 --- a/comfy_api_nodes/apis/luma.py +++ b/comfy_api_nodes/apis/luma.py @@ -1,15 +1,12 @@ from __future__ import annotations - -import torch - from enum import Enum from typing import Optional, Union +import torch from pydantic import BaseModel, Field, confloat - class LumaIO: LUMA_REF = "LUMA_REF" LUMA_CONCEPTS = "LUMA_CONCEPTS" @@ -183,13 +180,13 @@ class LumaAssets(BaseModel): class LumaImageRef(BaseModel): - '''Used for image gen''' + """Used for image gen""" url: str = Field(..., description='The URL of the image reference') weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') class LumaImageReference(BaseModel): - '''Used for video gen''' + """Used for video gen""" type: Optional[str] = Field('image', description='Input type, defaults to image') url: str = Field(..., description='The URL of the image') @@ -251,3 +248,32 @@ class LumaGeneration(BaseModel): assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') model: str = Field(..., description='The model used for the generation') request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") + + +class Luma2ImageRef(BaseModel): + url: str | None = None + data: str | None = None + media_type: str | None = None + + +class Luma2GenerationRequest(BaseModel): + prompt: str = Field(..., min_length=1, max_length=6000) + model: str | None = None + type: str | None = None + aspect_ratio: str | None = None + style: str | None = None + output_format: str | None = None + web_search: bool | None = None + image_ref: list[Luma2ImageRef] | None = None + source: Luma2ImageRef | None = None + + +class Luma2Generation(BaseModel): + id: str | None = None + type: str | None = None + state: str | None = None + model: str | None = None + created_at: str | None = None + output: list[LumaImageReference] | None = None + failure_reason: str | None = None + failure_code: str | None = None diff --git a/comfy_api_nodes/apis/openai.py b/comfy_api_nodes/apis/openai.py index b85ef252b..bee75d639 100644 --- a/comfy_api_nodes/apis/openai.py +++ b/comfy_api_nodes/apis/openai.py @@ -56,14 +56,14 @@ class ModelResponseProperties(BaseModel): instructions: str | None = Field(None) max_output_tokens: int | None = Field(None) model: str | None = Field(None) - temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0) + temperature: float | None = Field(None, description="Controls randomness in the response", ge=0.0, le=2.0) top_p: float | None = Field( - 1, + None, description="Controls diversity of the response via nucleus sampling", ge=0.0, le=1.0, ) - truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") + truncation: str | None = Field(None, description="Allowed values: 'auto' or 'disabled'") class ResponseProperties(BaseModel): diff --git a/comfy_api_nodes/apis/topaz.py b/comfy_api_nodes/apis/topaz.py index a9e6235a7..f91980e3d 100644 --- a/comfy_api_nodes/apis/topaz.py +++ b/comfy_api_nodes/apis/topaz.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional from pydantic import BaseModel, Field @@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel): grain: Optional[float] = Field(None, description="Grain after AI model processing") grainSize: Optional[float] = Field(None, description="Size of generated grain") recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") - creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") + creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.") isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") + prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)") + sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness") + realism: float | None = Field(None, description="ast-2 realism control") class OutputInformationVideo(BaseModel): @@ -90,7 +93,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): source: CreateVideoRequestSource = Field(...) - filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) + filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...) output: OutputInformationVideo = Field(...) overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 9ed6cd299..d92a7c382 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,10 +1,11 @@ -from typing import Optional - import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.luma import ( + Luma2Generation, + Luma2GenerationRequest, + Luma2ImageRef, LumaAspectRatio, LumaCharacterRef, LumaConceptChain, @@ -30,6 +31,7 @@ from comfy_api_nodes.util import ( download_url_to_video_output, poll_op, sync_op, + upload_image_to_comfyapi, upload_images_to_comfyapi, validate_string, ) @@ -212,9 +214,9 @@ class LumaImageGenerationNode(IO.ComfyNode): aspect_ratio: str, seed, style_image_weight: float, - image_luma_ref: Optional[LumaReferenceChain] = None, - style_image: Optional[torch.Tensor] = None, - character_image: Optional[torch.Tensor] = None, + image_luma_ref: LumaReferenceChain | None = None, + style_image: torch.Tensor | None = None, + character_image: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) # handle image_luma_ref @@ -434,7 +436,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): duration: str, loop: bool, seed, - luma_concepts: Optional[LumaConceptChain] = None, + luma_concepts: LumaConceptChain | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None @@ -533,7 +535,6 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): ], is_api_node=True, price_badge=PRICE_BADGE_VIDEO, - ) @classmethod @@ -644,6 +645,293 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( ) +def _luma2_uni1_common_inputs(max_image_refs: int) -> list: + return [ + IO.Combo.Input( + "style", + options=["auto", "manga"], + default="auto", + tooltip="Style preset. 'auto' picks based on the prompt; " + "'manga' applies a manga/anime aesthetic and requires a portrait " + "aspect ratio (2:3, 9:16, 1:2, 1:3).", + ), + IO.Boolean.Input( + "web_search", + default=False, + tooltip="Search the web for visual references before generating.", + ), + IO.Autogrow.Input( + "image_ref", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, max_image_refs + 1)], + min=0, + ), + optional=True, + tooltip=f"Up to {max_image_refs} reference images for style/content guidance.", + ), + ] + + +async def _luma2_upload_image_refs( + cls: type[IO.ComfyNode], + refs: dict | None, + max_count: int, +) -> list[Luma2ImageRef] | None: + if not refs: + return None + out: list[Luma2ImageRef] = [] + for key in refs: + url = await upload_image_to_comfyapi(cls, refs[key]) + out.append(Luma2ImageRef(url=url)) + if len(out) > max_count: + raise ValueError(f"Maximum {max_count} reference images are allowed.") + return out or None + + +async def _luma2_submit_and_poll( + cls: type[IO.ComfyNode], + request: Luma2GenerationRequest, +) -> Input.Image: + initial = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma_2/generations", method="POST"), + response_model=Luma2Generation, + data=request, + ) + if not initial.id: + raise RuntimeError("Luma 2 API did not return a generation id.") + final = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"), + response_model=Luma2Generation, + status_extractor=lambda r: r.state, + progress_extractor=lambda r: None, + ) + if not final.output: + msg = final.failure_reason or "no output returned" + raise RuntimeError(f"Luma 2 generation failed: {msg}") + url = final.output[0].url + if not url: + raise RuntimeError("Luma 2 generation completed without an output URL.") + return await download_url_to_image_tensor(url) + + +class LumaImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaImageNode2", + display_name="Luma UNI-1 Image", + category="api node/image/Luma", + description="Generate images from text using the Luma UNI-1 model.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. 1–6000 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "uni-1", + [ + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "3:1", + "2:1", + "16:9", + "3:2", + "1:1", + "2:3", + "9:16", + "1:2", + "1:3", + ], + default="auto", + tooltip="Output image aspect ratio. 'auto' lets " + "the model pick based on the prompt.", + ), + *_luma2_uni1_common_inputs(max_image_refs=9), + ], + ), + IO.DynamicCombo.Option( + "uni-1-max", + [ + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "3:1", + "2:1", + "16:9", + "3:2", + "1:1", + "2:3", + "9:16", + "1:2", + "1:3", + ], + default="auto", + tooltip="Output image aspect ratio. 'auto' lets " + "the model pick based on the prompt.", + ), + *_luma2_uni1_common_inputs(max_image_refs=9), + ], + ), + ], + tooltip="Model to use for generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]), + expr=""" + ( + $m := widgets.model; + $refs := $lookup(inputGroups, "model.image_ref"); + $base := $m = "uni-1-max" ? 0.1 : 0.0404; + {"type":"usd","usd": $round($base + 0.003 * $refs, 4)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=6000) + aspect_ratio = model["aspect_ratio"] + style = model["style"] + allowed_manga_ratios = {"2:3", "9:16", "1:2", "1:3"} + if style == "manga" and aspect_ratio != "auto" and aspect_ratio not in allowed_manga_ratios: + raise ValueError( + f"'manga' style requires a portrait aspect ratio " + f"({', '.join(sorted(allowed_manga_ratios))}) or 'auto'; got '{aspect_ratio}'." + ) + request = Luma2GenerationRequest( + prompt=prompt, + model=model["model"], + type="image", + aspect_ratio=aspect_ratio if aspect_ratio != "auto" else None, + style=style if style != "auto" else None, + output_format="png", + web_search=model["web_search"], + image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9), + ) + return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + + +class LumaImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="LumaImageEditNode2", + display_name="Luma UNI-1 Image Edit", + category="api node/image/Luma", + description="Edit an existing image with a text prompt using the Luma UNI-1 model.", + inputs=[ + IO.Image.Input( + "source", + tooltip="Source image to edit.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Description of the desired edit. 1–6000 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "uni-1", + _luma2_uni1_common_inputs(max_image_refs=8), + ), + IO.DynamicCombo.Option( + "uni-1-max", + _luma2_uni1_common_inputs(max_image_refs=8), + ), + ], + tooltip="Model to use for editing.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]), + expr=""" + ( + $m := widgets.model; + $refs := $lookup(inputGroups, "model.image_ref"); + $base := $m = "uni-1-max" ? 0.103 : 0.0434; + {"type":"usd","usd": $round($base + 0.003 * $refs, 4)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + source: Input.Image, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=6000) + request = Luma2GenerationRequest( + prompt=prompt, + model=model["model"], + type="image_edit", + source=Luma2ImageRef(url=await upload_image_to_comfyapi(cls, source)), + style=model["style"] if model["style"] != "auto" else None, + output_format="png", + web_search=model["web_search"], + image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8), + ) + return IO.NodeOutput(await _luma2_submit_and_poll(cls, request)) + + class LumaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -654,6 +942,8 @@ class LumaExtension(ComfyExtension): LumaImageToVideoGenerationNode, LumaReferenceNode, LumaConceptsNode, + LumaImageNode, + LumaImageEditNode, ] diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 21fe470ce..daed495da 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -39,16 +39,18 @@ STARTING_POINT_ID_PATTERN = r"" class SupportedOpenAIModel(str, Enum): - o4_mini = "o4-mini" - o1 = "o1" - o3 = "o3" - o1_pro = "o1-pro" - gpt_4_1 = "gpt-4.1" - gpt_4_1_mini = "gpt-4.1-mini" - gpt_4_1_nano = "gpt-4.1-nano" + gpt_5_5_pro = "gpt-5.5-pro" + gpt_5_5 = "gpt-5.5" gpt_5 = "gpt-5" gpt_5_mini = "gpt-5-mini" gpt_5_nano = "gpt-5-nano" + gpt_4_1 = "gpt-4.1" + gpt_4_1_mini = "gpt-4.1-mini" + gpt_4_1_nano = "gpt-4.1-nano" + o4_mini = "o4-mini" + o3 = "o3" + o1_pro = "o1-pro" + o1 = "o1" async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: @@ -739,6 +741,16 @@ class OpenAIChatNode(IO.ComfyNode): "usd": [0.002, 0.008], "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } } + : $contains($m, "gpt-5.5-pro") ? { + "type": "list_usd", + "usd": [0.03, 0.18], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5.5") ? { + "type": "list_usd", + "usd": [0.005, 0.03], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : $contains($m, "gpt-5-nano") ? { "type": "list_usd", "usd": [0.00005, 0.0004], diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 4d9075dcf..c1d485188 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -33,7 +33,7 @@ class OpenAIVideoSora2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIVideoSora2", - display_name="OpenAI Sora - Video (Deprecated)", + display_name="OpenAI Sora - Video (DEPRECATED)", category="api node/video/Sora", description=( "OpenAI video and audio generation.\n\n" diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index fe3666ec9..e79c16d3c 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -36,11 +36,15 @@ from comfy_api_nodes.util import ( ) UPSCALER_MODELS_MAP = { + "Astra 2": "ast-2", "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", "Starlight Precise 2.5": "slp-2.5", } +AST2_MAX_FRAMES = 9000 +AST2_MAX_FRAMES_WITH_PROMPT = 450 + class TopazImageEnhance(IO.ComfyNode): @classmethod @@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TopazVideoEnhance", - display_name="Topaz Video Enhance", + display_name="Topaz Video Enhance (Legacy)", category="api node/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", default=True), - IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), + IO.Combo.Input( + "upscaler_model", + options=[ + "Starlight (Astra) Fast", + "Starlight (Astra) Creative", + "Starlight Precise 2.5", + ], + ), IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", @@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) +class TopazVideoEnhanceV2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazVideoEnhanceV2", + display_name="Topaz Video Enhance", + category="api node/video/Topaz", + description="Breathe new life into video with powerful upscaling and recovery technology.", + inputs=[ + IO.Video.Input("video"), + IO.DynamicCombo.Input( + "upscaler_model", + options=[ + IO.DynamicCombo.Option( + "Astra 2", + [ + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), + IO.Float.Input( + "creativity", + default=0.5, + min=0.0, + max=1.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Creative strength of the upscale.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional descriptive (not instructive) scene prompt." + f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.", + ), + IO.Float.Input( + "sharp", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Pre-enhance sharpness: " + "0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.", + advanced=True, + ), + IO.Float.Input( + "realism", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Pulls output toward photographic realism." + "Leave at 0 for the model default.", + advanced=True, + ), + ], + ), + IO.DynamicCombo.Option( + "Starlight (Astra) Fast", + [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),], + ), + IO.DynamicCombo.Option( + "Starlight (Astra) Creative", + [ + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), + IO.Combo.Input( + "creativity", + options=["low", "middle", "high"], + default="low", + tooltip="Creative strength of the upscale.", + ), + ], + ), + IO.DynamicCombo.Option( + "Starlight Precise 2.5", + [IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])], + ), + IO.DynamicCombo.Option("Disabled", []), + ], + ), + IO.DynamicCombo.Input( + "interpolation_model", + options=[ + IO.DynamicCombo.Option("Disabled", []), + IO.DynamicCombo.Option( + "apo-8", + [ + IO.Int.Input( + "interpolation_frame_rate", + default=60, + min=15, + max=240, + display_mode=IO.NumberDisplay.number, + tooltip="Output frame rate.", + ), + IO.Int.Input( + "interpolation_slowmo", + default=1, + min=1, + max=16, + display_mode=IO.NumberDisplay.number, + tooltip="Slow-motion factor applied to the input video. " + "For example, 2 makes the output twice as slow and doubles the duration.", + advanced=True, + ), + IO.Boolean.Input( + "interpolation_duplicate", + default=False, + tooltip="Analyze the input for duplicate frames and remove them.", + advanced=True, + ), + IO.Float.Input( + "interpolation_duplicate_threshold", + default=0.01, + min=0.001, + max=0.1, + step=0.001, + display_mode=IO.NumberDisplay.number, + tooltip="Detection sensitivity for duplicate frames.", + advanced=True, + ), + ], + ), + ], + ), + IO.Combo.Input( + "dynamic_compression_level", + options=["Low", "Mid", "High"], + default="Low", + tooltip="CQP level.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=[ + "upscaler_model", + "upscaler_model.upscaler_resolution", + "interpolation_model", + ]), + expr=""" + ( + $model := $lookup(widgets, "upscaler_model"); + $res := $lookup(widgets, "upscaler_model.upscaler_resolution"); + $interp := $lookup(widgets, "interpolation_model"); + $is4k := $contains($res, "4k"); + $hasInterp := $interp != "disabled"; + $rates := { + "starlight (astra) fast": {"hd": 0.43, "uhd": 0.85}, + "starlight precise 2.5": {"hd": 0.70, "uhd": 1.54}, + "astra 2": {"hd": 1.72, "uhd": 2.85}, + "starlight (astra) creative": {"hd": 2.25, "uhd": 3.99} + }; + $surcharge := $is4k ? 0.28 : 0.14; + $entry := $lookup($rates, $model); + $base := $is4k ? $entry.uhd : $entry.hd; + $hi := $base + ($hasInterp ? $surcharge : 0); + $model = "disabled" + ? {"type":"text","text":"Interpolation only"} + : ($hasInterp + ? {"type":"text","text":"~" & $string($base) & "–" & $string($hi) & " credits/src frame"} + : {"type":"text","text":"~" & $string($base) & " credits/src frame"}) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + upscaler_model: dict, + interpolation_model: dict, + dynamic_compression_level: str = "Low", + ) -> IO.NodeOutput: + upscaler_choice = upscaler_model["upscaler_model"] + interpolation_choice = interpolation_model["interpolation_model"] + if upscaler_choice == "Disabled" and interpolation_choice == "Disabled": + raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") + validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() + src_video_stream = video.get_stream_source() + target_width = src_width + target_height = src_height + target_frame_rate = src_frame_rate + filters = [] + if upscaler_choice != "Disabled": + if "1080p" in upscaler_model["upscaler_resolution"]: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 + model_id = UPSCALER_MODELS_MAP[upscaler_choice] + if model_id == "slc-1": + filters.append( + VideoEnhancementFilter( + model=model_id, + creativity=upscaler_model["creativity"], + isOptimizedMode=True, + ) + ) + elif model_id == "ast-2": + n_frames = video.get_frame_count() + ast2_prompt = (upscaler_model["prompt"] or "").strip() + if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT: + raise ValueError( + f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames " + f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip." + ) + if n_frames > AST2_MAX_FRAMES: + raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.") + realism = upscaler_model["realism"] + filters.append( + VideoEnhancementFilter( + model=model_id, + creativity=upscaler_model["creativity"], + prompt=(ast2_prompt or None), + sharp=upscaler_model["sharp"], + realism=(realism if realism > 0 else None), + ) + ) + else: + filters.append(VideoEnhancementFilter(model=model_id)) + if interpolation_choice != "Disabled": + target_frame_rate = interpolation_model["interpolation_frame_rate"] + filters.append( + VideoFrameInterpolationFilter( + model=interpolation_choice, + slowmo=interpolation_model["interpolation_slowmo"], + fps=interpolation_model["interpolation_frame_rate"], + duplicate=interpolation_model["interpolation_duplicate"], + duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"], + ), + ) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/video/", method="POST"), + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( + container="mp4", + size=get_fs_object_size(src_video_stream), + duration=int(duration_sec), + frameCount=video.get_frame_count(), + frameRate=src_frame_rate, + resolution=Resolution(width=src_width, height=src_height), + ), + filters=filters, + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), + frameRate=target_frame_rate, + audioCodec="AAC", + audioTransfer="Copy", + dynamicCompressionLevel=dynamic_compression_level, + ), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + upload_res = await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/accept", + method="PATCH", + ), + response_model=VideoAcceptResponse, + wait_label="Preparing upload", + final_label_on_success="Upload started", + ) + if len(upload_res.urls) > 1: + raise NotImplementedError( + "Large files are not currently supported. Please open an issue in the ComfyUI repository." + ) + async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session: + if isinstance(src_video_stream, BytesIO): + src_video_stream.seek(0) + async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + else: + with builtins.open(src_video_stream, "rb") as video_file: + async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", + method="PATCH", + ), + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( + uploadResults=[ + VideoCompleteUploadRequestPart( + partNum=1, + eTag=upload_etag, + ), + ], + ), + wait_label="Finalizing upload", + final_label_on_success="Upload completed", + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), + response_model=VideoStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), + poll_interval=10.0, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) + + class TopazExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TopazImageEnhance, TopazVideoEnhance, + TopazVideoEnhanceV2, ] diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index a0b8d35e1..8e1ba91ba 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -19,6 +19,8 @@ from comfy import utils from comfy_api.latest import IO from server import PromptServer +from comfy.deploy_environment import get_deploy_environment + from . import request_logger from ._helpers import ( default_base_url, @@ -624,6 +626,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? payload_headers.update(get_auth_header(cfg.node_cls)) + payload_headers["Comfy-Env"] = get_deploy_environment() if cfg.endpoint.headers: payload_headers.update(cfg.endpoint.headers) diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py index cf4f6e1e1..36bc79dc3 100644 --- a/comfy_extras/frame_interpolation_models/film_net.py +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -199,6 +199,9 @@ class FILMNet(nn.Module): def get_dtype(self): return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + def memory_used_forward(self, shape, dtype): + return 1700 * shape[1] * shape[2] * dtype.itemsize + def _build_warp_grids(self, H, W, device): """Pre-compute warp grids for all pyramid levels.""" if (H, W) in self._warp_grids: diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py index 03cb34c50..ad6edbec9 100644 --- a/comfy_extras/frame_interpolation_models/ifnet.py +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -74,6 +74,9 @@ class IFNet(nn.Module): def get_dtype(self): return self.encode.cnn0.weight.dtype + def memory_used_forward(self, shape, dtype): + return 300 * shape[1] * shape[2] * dtype.itemsize + def _build_warp_grids(self, H, W, device): if (H, W) in self._warp_grids: return diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py new file mode 100644 index 000000000..09ee886fd --- /dev/null +++ b/comfy_extras/nodes_ar_video.py @@ -0,0 +1,84 @@ +""" +ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). + - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors + - SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop +""" + +import torch +from typing_extensions import override + +import comfy.model_management +import comfy.samplers +from comfy_api.latest import ComfyExtension, io + + +class EmptyARVideoLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyARVideoLatent", + category="latent/video", + inputs=[ + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=1024, step=4), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Latent.Output(display_name="LATENT"), + ], + ) + + @classmethod + def execute(cls, width, height, length, batch_size) -> io.NodeOutput: + lat_t = ((length - 1) // 4) + 1 + latent = torch.zeros( + [batch_size, 16, lat_t, height // 8, width // 8], + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) + + +class SamplerARVideo(io.ComfyNode): + """Sampler for autoregressive video models (Causal Forcing, Self-Forcing). + + All AR-loop parameters are owned by this node so they live in the workflow. + Add new widgets here as the AR sampler grows new options. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerARVideo", + display_name="Sampler AR Video", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input( + "num_frame_per_block", + default=1, min=1, max=64, + tooltip="Frames per autoregressive block. 1 = framewise, " + "3 = chunkwise. Must match the checkpoint's training mode.", + ), + ], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls, num_frame_per_block) -> io.NodeOutput: + extra_options = { + "num_frame_per_block": num_frame_per_block, + } + return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options)) + + +class ARVideoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + EmptyARVideoLatent, + SamplerARVideo, + ] + + +async def comfy_entrypoint() -> ARVideoExtension: + return ARVideoExtension() diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 3bc9fccb3..5b4423734 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode): @classmethod def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: - batch_size = min(len(image), len(alpha)) - out_images = [] - + batch_size = max(len(image), len(alpha)) alpha = 1.0 - resize_mask(alpha, image.shape[1:]) - for i in range(batch_size): - out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) - - return io.NodeOutput(torch.stack(out_images)) + alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size) + image = comfy.utils.repeat_to_batch_size(image, batch_size) + return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1)) class CompositingExtension(ComfyExtension): diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index a3b00d36e..9dd34cfb8 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode): model = cls._detect_and_load(sd) dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 model.eval().to(dtype) - patcher = comfy.model_patcher.ModelPatcher( + patcher = comfy.model_patcher.CoreModelPatcher( model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), @@ -78,7 +78,7 @@ class FrameInterpolate(io.ComfyNode): return io.Schema( node_id="FrameInterpolate", display_name="Frame Interpolate", - category="image/video", + category="video", search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ FrameInterpolationModel.Input("interp_model"), @@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode): if num_frames < 2 or multiplier < 2: return io.NodeOutput(images) - model_management.load_model_gpu(interp_model) device = interp_model.load_device dtype = interp_model.model_dtype() inference_model = interp_model.model - - # Free VRAM for inference activations (model weights + ~20x a single frame's worth) - H, W = images.shape[1], images.shape[2] - activation_mem = H * W * 3 * images.element_size() * 20 - model_management.free_memory(activation_mem, device) + activation_mem = inference_model.memory_used_forward(images.shape, dtype) + model_management.load_models_gpu([interp_model], memory_required=activation_mem) align = getattr(inference_model, "pad_align", 1) + H, W = images.shape[1], images.shape[2] # Prepare a single padded frame on device for determining output dimensions def prepare_frame(idx): diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py index 3d943be67..58af9ae82 100644 --- a/comfy_extras/nodes_image_compare.py +++ b/comfy_extras/nodes_image_compare.py @@ -11,7 +11,7 @@ class ImageCompare(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCompare", - display_name="Image Compare", + display_name="Compare Images", description="Compares two images side by side with a slider.", category="image", essentials_category="Image Tools", diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index a77f0641f..1ac740d1d 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -24,7 +24,7 @@ class ImageCrop(IO.ComfyNode): return IO.Schema( node_id="ImageCrop", search_aliases=["trim"], - display_name="Image Crop (Deprecated)", + display_name="Crop Image (DEPRECATED)", category="image/transform", is_deprecated=True, essentials_category="Image Tools", @@ -56,7 +56,7 @@ class ImageCropV2(IO.ComfyNode): return IO.Schema( node_id="ImageCropV2", search_aliases=["trim"], - display_name="Image Crop", + display_name="Crop Image", category="image/transform", essentials_category="Image Tools", has_intermediate_output=True, @@ -109,6 +109,7 @@ class RepeatImageBatch(IO.ComfyNode): return IO.Schema( node_id="RepeatImageBatch", search_aliases=["duplicate image", "clone image"], + display_name="Repeat Image Batch", category="image/batch", inputs=[ IO.Image.Input("image"), @@ -131,6 +132,7 @@ class ImageFromBatch(IO.ComfyNode): return IO.Schema( node_id="ImageFromBatch", search_aliases=["select image", "pick from batch", "extract image"], + display_name="Get Image from Batch", category="image/batch", inputs=[ IO.Image.Input("image"), @@ -157,7 +159,8 @@ class ImageAddNoise(IO.ComfyNode): return IO.Schema( node_id="ImageAddNoise", search_aliases=["film grain"], - category="image", + display_name="Add Noise to Image", + category="image/postprocessing", inputs=[ IO.Image.Input("image"), IO.Int.Input( @@ -259,7 +262,7 @@ class ImageStitch(IO.ComfyNode): return IO.Schema( node_id="ImageStitch", search_aliases=["combine images", "join images", "concatenate images", "side by side"], - display_name="Image Stitch", + display_name="Stitch Images", description="Stitches image2 to image1 in the specified direction.\n" "If image2 is not provided, returns image1 unchanged.\n" "Optional spacing can be added between images.", @@ -434,6 +437,7 @@ class ResizeAndPadImage(IO.ComfyNode): return IO.Schema( node_id="ResizeAndPadImage", search_aliases=["fit to size"], + display_name="Resize And Pad Image", category="image/transform", inputs=[ IO.Image.Input("image"), @@ -485,6 +489,7 @@ class SaveSVGNode(IO.ComfyNode): return IO.Schema( node_id="SaveSVGNode", search_aliases=["export vector", "save vector graphics"], + display_name="Save SVG", description="Save SVG files on disk.", category="image/save", inputs=[ @@ -591,7 +596,7 @@ class ImageRotate(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageRotate", - display_name="Image Rotate", + display_name="Rotate Image", search_aliases=["turn", "flip orientation"], category="image/transform", essentials_category="Image Tools", @@ -624,6 +629,7 @@ class ImageFlip(IO.ComfyNode): return IO.Schema( node_id="ImageFlip", search_aliases=["mirror", "reflect"], + display_name="Flip Image", category="image/transform", inputs=[ IO.Image.Input("image"), @@ -650,6 +656,7 @@ class ImageScaleToMaxDimension(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageScaleToMaxDimension", + display_name="Scale Image to Max Dimension", category="image/upscaling", inputs=[ IO.Image.Input("image"), @@ -709,7 +716,7 @@ class SplitImageToTileList(IO.ComfyNode): def get_grid_coords(width, height, tile_width, tile_height, overlap): coords = [] stride_x = round(max(tile_width * 0.25, tile_width - overlap)) - stride_y = round(max(tile_width * 0.25, tile_height - overlap)) + stride_y = round(max(tile_height * 0.25, tile_height - overlap)) y = 0 while y < height: diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 3ec635c75..2c1f63afb 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -147,7 +147,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode): z_channels = audio_vae.latent_channels audio_freq = audio_vae.first_stage_model.latent_frequency_bins - sampling_rate = int(audio_vae.first_stage_model.sample_rate) num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) @@ -159,7 +158,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode): return io.NodeOutput( { "samples": audio_latents, - "sample_rate": sampling_rate, "type": "audio", } ) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 8ca947718..43a933dac 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -80,7 +80,8 @@ class ImageCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCompositeMasked", - search_aliases=["paste image", "overlay", "layer"], + search_aliases=["overlay", "layer", "paste image", "images composition"], + display_name="Image Composite Masked", category="image", inputs=[ IO.Image.Input("destination"), @@ -201,6 +202,7 @@ class InvertMask(IO.ComfyNode): return IO.Schema( node_id="InvertMask", search_aliases=["reverse mask", "flip mask"], + display_name="Invert Mask", category="mask", inputs=[ IO.Mask.Input("mask"), @@ -222,6 +224,7 @@ class CropMask(IO.ComfyNode): return IO.Schema( node_id="CropMask", search_aliases=["cut mask", "extract mask region", "mask slice"], + display_name="Crop Mask", category="mask", inputs=[ IO.Mask.Input("mask"), @@ -247,7 +250,8 @@ class MaskComposite(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskComposite", - search_aliases=["combine masks", "blend masks", "layer masks"], + search_aliases=["combine masks", "blend masks", "layer masks", "masks composition"], + display_name="Combine Masks", category="mask", inputs=[ IO.Mask.Input("destination"), @@ -298,6 +302,7 @@ class FeatherMask(IO.ComfyNode): return IO.Schema( node_id="FeatherMask", search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"], + display_name="Feather Mask", category="mask", inputs=[ IO.Mask.Input("mask"), diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 4ab2fb7e8..c01b9436d 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -59,7 +59,8 @@ class ImageRGBToYUV(io.ComfyNode): return io.Schema( node_id="ImageRGBToYUV", search_aliases=["color space conversion"], - category="image/batch", + display_name="Image RGB to YUV", + category="image/color", inputs=[ io.Image.Input("image"), ], @@ -81,7 +82,8 @@ class ImageYUVToRGB(io.ComfyNode): return io.Schema( node_id="ImageYUVToRGB", search_aliases=["color space conversion"], - category="image/batch", + display_name="Image YUV to RGB", + category="image/color", inputs=[ io.Image.Input("Y"), io.Image.Input("U"), diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c932b747a..d938a2035 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -20,7 +20,8 @@ class Blend(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageBlend", - display_name="Image Blend", + search_aliases=["mix images"], + display_name="Blend Images", category="image/postprocessing", essentials_category="Image Tools", inputs=[ @@ -224,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageScaleToTotalPixels", + display_name="Scale Image to Total Pixels", category="image/upscaling", inputs=[ io.Image.Input("image"), @@ -568,7 +570,7 @@ class BatchImagesNode(io.ComfyNode): return io.Schema( node_id="BatchImagesNode", display_name="Batch Images", - category="image", + category="image/batch", essentials_category="Image Tools", search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ @@ -666,12 +668,13 @@ class ColorTransfer(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ColorTransfer", + display_name="Color Transfer", category="image/postprocessing", description="Match the colors of one image to another using various algorithms.", search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], inputs=[ io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), - io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), + io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."), io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), io.DynamicCombo.Input("source_stats", tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 9c2e98758..33373266b 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -9,7 +9,8 @@ class String(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PrimitiveString", - display_name="String", + search_aliases=["text", "string", "text box", "prompt"], + display_name="Text String", category="utils/primitive", inputs=[ io.String.Input("value"), @@ -27,7 +28,8 @@ class StringMultiline(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PrimitiveStringMultiline", - display_name="String (Multiline)", + search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], + display_name="Text String (Multiline)", category="utils/primitive", essentials_category="Basics", inputs=[ @@ -49,7 +51,7 @@ class Int(io.ComfyNode): display_name="Int", category="utils/primitive", inputs=[ - io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), + io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed), ], outputs=[io.Int.Output()], ) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 604076c4e..925a40da8 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -10,9 +10,9 @@ class StringConcatenate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringConcatenate", - display_name="Text Concatenate", - category="utils/string", - search_aliases=["Concatenate", "text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], + search_aliases=["concatenate", "text concat", "join text", "merge text", "combine strings", "string concat", "append text", "combine text"], + display_name="Concatenate Text", + category="text", inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), @@ -33,9 +33,9 @@ class StringSubstring(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringSubstring", - search_aliases=["Substring", "extract text", "text portion"], - display_name="Text Substring", - category="utils/string", + search_aliases=["substring", "extract text", "text portion"], + display_name="Substring", + category="text", inputs=[ io.String.Input("string", multiline=True), io.Int.Input("start"), @@ -58,7 +58,7 @@ class StringLength(io.ComfyNode): node_id="StringLength", search_aliases=["character count", "text size", "string length"], display_name="Text Length", - category="utils/string", + category="text", inputs=[ io.String.Input("string", multiline=True), ], @@ -77,9 +77,9 @@ class CaseConverter(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CaseConverter", - search_aliases=["Case Converter", "text case", "uppercase", "lowercase", "capitalize"], - display_name="Text Case Converter", - category="utils/string", + search_aliases=["case converter", "text case", "uppercase", "lowercase", "capitalize"], + display_name="Convert Text Case", + category="text", inputs=[ io.String.Input("string", multiline=True), io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]), @@ -110,9 +110,9 @@ class StringTrim(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringTrim", - search_aliases=["Trim", "clean whitespace", "remove whitespace", "strip"], - display_name="Text Trim", - category="utils/string", + search_aliases=["trim", "clean whitespace", "remove whitespace", "remove spaces","strip"], + display_name="Trim Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.Combo.Input("mode", options=["Both", "Left", "Right"]), @@ -141,9 +141,9 @@ class StringReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringReplace", - search_aliases=["Replace", "find and replace", "substitute", "swap text"], - display_name="Text Replace", - category="utils/string", + search_aliases=["replace", "find and replace", "substitute", "swap text"], + display_name="Replace Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("find", multiline=True), @@ -164,9 +164,9 @@ class StringContains(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringContains", - search_aliases=["Contains", "text includes", "string includes"], - display_name="Text Contains", - category="utils/string", + search_aliases=["contains", "text includes", "string includes"], + display_name="Contains Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("substring", multiline=True), @@ -192,9 +192,9 @@ class StringCompare(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringCompare", - search_aliases=["Compare", "text match", "string equals", "starts with", "ends with"], - display_name="Text Compare", - category="utils/string", + search_aliases=["compare", "text match", "string equals", "starts with", "ends with"], + display_name="Compare Text", + category="text", inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), @@ -228,9 +228,9 @@ class RegexMatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexMatch", - search_aliases=["Regex Match", "regex", "pattern match", "text contains", "string match"], - display_name="Text Match", - category="utils/string", + search_aliases=["regex match", "regex", "pattern match", "text contains", "string match"], + display_name="Match Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("regex_pattern", multiline=True), @@ -269,9 +269,9 @@ class RegexExtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexExtract", - search_aliases=["Regex Extract", "regex", "pattern extract", "text parser", "parse text"], - display_name="Text Extract Substring", - category="utils/string", + search_aliases=["regex extract", "regex", "pattern extract", "text parser", "parse text"], + display_name="Extract Text", + category="text", inputs=[ io.String.Input("string", multiline=True), io.String.Input("regex_pattern", multiline=True), @@ -344,9 +344,9 @@ class RegexReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexReplace", - search_aliases=["Regex Replace", "regex", "pattern replace", "regex replace", "substitution"], - display_name="Text Replace (Regex)", - category="utils/string", + search_aliases=["regex replace", "regex", "pattern replace", "substitution"], + display_name="Replace Text (Regex)", + category="text", description="Find and replace text using regex patterns.", inputs=[ io.String.Input("string", multiline=True), @@ -381,8 +381,8 @@ class JsonExtractString(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="JsonExtractString", - display_name="Extract String from JSON", - category="utils/string", + display_name="Extract Text from JSON", + category="text", search_aliases=["json", "extract json", "parse json", "json value", "read json"], inputs=[ io.String.Input("json_string", multiline=True), diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index 1f46d820f..1661a1011 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode): io.Clip.Input("clip"), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.Image.Input("image", optional=True), + io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."), + io.Audio.Input("audio", optional=True), io.Int.Input("max_length", default=256, min=1, max=2048), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), @@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput: - tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking) + tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio) # Get sampling parameters from dynamic combo do_sample = sampling_mode.get("sampling_mode") == "on" @@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode): seed=seed ) - generated_text = clip.decode(generated_ids, skip_special_tokens=True) + generated_text = clip.decode(generated_ids) + return io.NodeOutput(generated_text) @@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput: if image is None: formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n" else: formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n" - return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template) + return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio) class TextgenExtension(ComfyExtension): diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 5c096c232..719acf2f1 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -17,7 +17,8 @@ class SaveWEBM(io.ComfyNode): return io.Schema( node_id="SaveWEBM", search_aliases=["export webm"], - category="image/video", + display_name="Save WEBM", + category="video", is_experimental=True, inputs=[ io.Image.Input("images"), @@ -72,7 +73,7 @@ class SaveVideo(io.ComfyNode): node_id="SaveVideo", search_aliases=["export video"], display_name="Save Video", - category="image/video", + category="video", essentials_category="Basics", description="Saves the input images to your ComfyUI output directory.", inputs=[ @@ -121,7 +122,7 @@ class CreateVideo(io.ComfyNode): node_id="CreateVideo", search_aliases=["images to video"], display_name="Create Video", - category="image/video", + category="video", description="Create a video from images.", inputs=[ io.Image.Input("images", tooltip="The images to create a video from."), @@ -146,7 +147,7 @@ class GetVideoComponents(io.ComfyNode): node_id="GetVideoComponents", search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", - category="image/video", + category="video", description="Extracts all components from a video: frames, audio, and framerate.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), @@ -174,7 +175,7 @@ class LoadVideo(io.ComfyNode): node_id="LoadVideo", search_aliases=["import video", "open video", "video file"], display_name="Load Video", - category="image/video", + category="video", essentials_category="Basics", inputs=[ io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video), @@ -216,7 +217,7 @@ class VideoSlice(io.ComfyNode): "frame load cap", "start time", ], - category="image/video", + category="video", essentials_category="Video Tools", inputs=[ io.Video.Input("video"), diff --git a/execution.py b/execution.py index 5a6d3404c..f37d0360d 100644 --- a/execution.py +++ b/execution.py @@ -15,6 +15,7 @@ import torch from comfy.cli_args import args import comfy.memory_management import comfy.model_management +import comfy.model_prefetch import comfy_aimdo.model_vbar from latent_preview import set_preview_method @@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if args.verbose == "DEBUG": comfy_aimdo.control.analyze() comfy.model_management.reset_cast_buffers() + comfy.model_prefetch.cleanup_prefetch_queues() comfy_aimdo.model_vbar.vbars_reset_watermark_limits() if has_pending_tasks: @@ -1017,7 +1019,12 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): combo_options = extra_info.get("options", []) else: combo_options = input_type - if val not in combo_options: + is_multiselect = extra_info.get("multiselect", False) + if is_multiselect and isinstance(val, list): + invalid_vals = [v for v in val if v not in combo_options] + else: + invalid_vals = [val] if val not in combo_options else [] + if invalid_vals: input_config = info list_info = "" @@ -1032,7 +1039,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): error = { "type": "value_not_in_list", "message": "Value not in list", - "details": f"{x}: '{val}' not in {list_info}", + "details": f"{x}: {', '.join(repr(v) for v in invalid_vals)} not in {list_info}", "extra_info": { "input_name": x, "input_config": input_config, diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 34df01681..9c395c0b2 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -28,7 +28,7 @@ #config for a1111 ui #all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed -#a111: +#a1111: # base_path: path/to/stable-diffusion-webui/ # checkpoints: models/Stable-diffusion # configs: models/Stable-diffusion diff --git a/main.py b/main.py index dbaf2745c..a6fdaf43c 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,21 @@ import comfy.options comfy.options.enable_args_parsing() +from comfy.cli_args import args + +if args.list_feature_flags: + import json + from comfy_api.feature_flags import CLI_FEATURE_FLAG_REGISTRY + print(json.dumps(CLI_FEATURE_FLAG_REGISTRY, indent=2)) # noqa: T201 + raise SystemExit(0) + import os import importlib.util import shutil import importlib.metadata import folder_paths import time -from comfy.cli_args import args, enables_dynamic_vram +from comfy.cli_args import enables_dynamic_vram from app.logger import setup_logger setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) diff --git a/node_helpers.py b/node_helpers.py index d3d834516..cac4e88dd 100644 --- a/node_helpers.py +++ b/node_helpers.py @@ -86,6 +86,6 @@ def image_alpha_fix(destination, source): if destination.shape[-1] < source.shape[-1]: source = source[...,:destination.shape[-1]] elif destination.shape[-1] > source.shape[-1]: - destination = torch.nn.functional.pad(destination, (0, 1)) - destination[..., -1] = 1.0 + source = torch.nn.functional.pad(source, (0, 1)) + source[..., -1] = 1.0 return destination, source diff --git a/nodes.py b/nodes.py index 99dc07227..cf61d9df0 100644 --- a/nodes.py +++ b/nodes.py @@ -1694,26 +1694,27 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" + def load_image(self, image): image_path = folder_paths.get_annotated_filepath(image) + dtype = comfy.model_management.intermediate_dtype() + device = comfy.model_management.intermediate_device() + components = InputImpl.VideoFromFile(image_path).get_components() if components.images.shape[0] > 0: - return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu")) + return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device)) + # This code is left here to handle animated webp which pyav does not support loading img = node_helpers.pillow(Image.open, image_path) output_images = [] output_masks = [] w, h = None, None - dtype = comfy.model_management.intermediate_dtype() - for i in ImageSequence.Iterator(img): i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) image = i.convert("RGB") if len(output_images) == 0: @@ -1728,25 +1729,15 @@ class LoadImage: if 'A' in i.getbands(): mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) - elif i.mode == 'P' and 'transparency' in i.info: - mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 - mask = 1. - torch.from_numpy(mask) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") + mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") output_images.append(image.to(dtype=dtype)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype)) - if img.format == "MPO": - break # ignore all frames except the first one for MPO format + output_image = torch.cat(output_images, dim=0) + output_mask = torch.cat(output_masks, dim=0) - if len(output_images) > 1: - output_image = torch.cat(output_images, dim=0) - output_mask = torch.cat(output_masks, dim=0) - else: - output_image = output_images[0] - output_mask = output_masks[0] - - return (output_image, output_mask) + return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype)) @classmethod def IS_CHANGED(s, image): @@ -1763,57 +1754,49 @@ class LoadImage: return True -class LoadImageMask: + +class LoadImageMask(LoadImage): ESSENTIALS_CATEGORY = "Image Tools" SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] _color_channels = ["alpha", "red", "green", "blue"] + @classmethod def INPUT_TYPES(s): - input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] - return {"required": - {"image": (sorted(files), {"image_upload": True}), - "channel": (s._color_channels, ), } - } + types = super().INPUT_TYPES() + return { + "required": { + **types["required"], + "channel": (s._color_channels, ) + } + } CATEGORY = "mask" - RETURN_TYPES = ("MASK",) - FUNCTION = "load_image" - def load_image(self, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - i = node_helpers.pillow(Image.open, image_path) - i = node_helpers.pillow(ImageOps.exif_transpose, i) - if i.getbands() != ("R", "G", "B", "A"): - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - i = i.convert("RGBA") - mask = None + FUNCTION = "load_image_mask" + + def load_image_mask(self, image, channel): + image_tensor, mask_tensor = super().load_image(image) c = channel[0].upper() - if c in i.getbands(): - mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 - mask = torch.from_numpy(mask) - if c == 'A': - mask = 1. - mask + + if c == 'A': + return (mask_tensor,) + + channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0) + + if channel_idx < image_tensor.shape[-1]: + return (image_tensor[..., channel_idx].clone(),) else: - mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") - return (mask.unsqueeze(0),) + empty_mask = torch.zeros( + image_tensor.shape[:-1], + dtype=image_tensor.dtype, + device=image_tensor.device + ) + return (empty_mask,) @classmethod def IS_CHANGED(s, image, channel): - image_path = folder_paths.get_annotated_filepath(image) - m = hashlib.sha256() - with open(image_path, 'rb') as f: - m.update(f.read()) - return m.digest().hex() - - @classmethod - def VALIDATE_INPUTS(s, image): - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - - return True + return super().IS_CHANGED(image) class LoadImageOutput(LoadImage): @@ -1904,7 +1887,7 @@ class ImageInvert: RETURN_TYPES = ("IMAGE",) FUNCTION = "invert" - CATEGORY = "image" + CATEGORY = "image/color" def invert(self, image): s = 1.0 - image @@ -1920,7 +1903,7 @@ class ImageBatch: RETURN_TYPES = ("IMAGE",) FUNCTION = "batch" - CATEGORY = "image" + CATEGORY = "image/batch" DEPRECATED = True def batch(self, image1, image2): @@ -1977,7 +1960,7 @@ class ImagePadForOutpaint: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "expand_image" - CATEGORY = "image" + CATEGORY = "image/transform" def expand_image(self, image, left, top, right, bottom, feathering): d1, d2, d3, d4 = image.size() @@ -2120,7 +2103,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetMask": "Conditioning (Set Mask)", - "ControlNetApply": "Apply ControlNet (OLD)", + "ControlNetApply": "Apply ControlNet (DEPRECATED)", "ControlNetApplyAdvanced": "Apply ControlNet", # Latent "VAEEncodeForInpaint": "VAE Encode (for Inpainting)", @@ -2138,6 +2121,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", # Image + "EmptyImage": "Empty Image", "SaveImage": "Save Image", "PreviewImage": "Preview Image", "LoadImage": "Load Image", @@ -2145,15 +2129,15 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoadImageOutput": "Load Image (from Outputs)", "ImageScale": "Upscale Image", "ImageScaleBy": "Upscale Image By", - "ImageInvert": "Invert Image", + "ImageInvert": "Invert Image Colors", "ImagePadForOutpaint": "Pad Image for Outpainting", - "ImageBatch": "Batch Images", - "ImageCrop": "Image Crop", - "ImageStitch": "Image Stitch", - "ImageBlend": "Image Blend", - "ImageBlur": "Image Blur", - "ImageQuantize": "Image Quantize", - "ImageSharpen": "Image Sharpen", + "ImageBatch": "Batch Images (DEPRECATED)", + "ImageCrop": "Crop Image", + "ImageStitch": "Stitch Images", + "ImageBlend": "Blend Images", + "ImageBlur": "Blur Image", + "ImageQuantize": "Quantize Image", + "ImageSharpen": "Sharpen Image", "ImageScaleToTotalPixels": "Scale Image to Total Pixels", "GetImageSize": "Get Image Size", # _for_testing @@ -2278,7 +2262,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}") return False else: - logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).") + logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).") return False except Exception as e: logging.warning(traceback.format_exc()) @@ -2428,6 +2412,7 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_ar_video.py", "nodes_image_compare.py", "nodes_zimage.py", "nodes_glsl.py", diff --git a/openapi.yaml b/openapi.yaml index a0736a529..29b5f544b 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -2467,7 +2467,12 @@ components: description: Device type (cuda, mps, cpu, etc.) index: type: number - description: Device index + nullable: true + description: | + Device index within its type (e.g. CUDA ordinal for `cuda:0`, + `cuda:1`). `null` for devices with no index, including the CPU + device returned in `--cpu` mode (PyTorch's `torch.device('cpu').index` + is `None`). vram_total: type: number description: Total VRAM in bytes @@ -2623,7 +2628,18 @@ components: description: Alternative search terms for finding this node essentials_category: type: string - description: Category override used by the essentials pack + nullable: true + description: | + Category override used by the essentials pack. The + `essentials_category` key may be present with a string value, + present and `null`, or absent entirely: + + - V1 nodes: `essentials_category` is **omitted** when the node + class doesn't define an `ESSENTIALS_CATEGORY` attribute, and + **`null`** if the attribute is explicitly set to `None`. + - V3 nodes (`comfy_api.latest.io`): `essentials_category` is + **always present**, and **`null`** for nodes whose `Schema` + doesn't populate it. # ------------------------------------------------------------------- # Models diff --git a/requirements.txt b/requirements.txt index 932034076..e9415f2fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.42.15 -comfyui-workflow-templates==0.9.66 +comfyui-workflow-templates==0.9.69 comfyui-embedded-docs==0.4.4 torch torchsde diff --git a/server.py b/server.py index 881da8e66..0e85635d3 100644 --- a/server.py +++ b/server.py @@ -1,3 +1,4 @@ +import errno import os import sys import asyncio @@ -559,7 +560,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type=f'image/{image_format}', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) if 'channel' not in request.rel_url.query: channel = 'rgba' @@ -579,7 +580,7 @@ class PromptServer(): buffer.seek(0) return web.Response(body=buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) elif channel == 'a': with Image.open(file) as img: @@ -596,7 +597,7 @@ class PromptServer(): alpha_buffer.seek(0) return web.Response(body=alpha_buffer.read(), content_type='image/png', - headers={"Content-Disposition": f"filename=\"{filename}\""}) + headers={"Content-Disposition": f"attachment; filename=\"{filename}\""}) else: # Use the content type from asset resolution if available, # otherwise guess from the filename. @@ -613,7 +614,7 @@ class PromptServer(): return web.FileResponse( file, headers={ - "Content-Disposition": f"filename=\"{filename}\"", + "Content-Disposition": f"attachment; filename=\"{filename}\"", "Content-Type": content_type } ) @@ -1245,7 +1246,13 @@ class PromptServer(): address = addr[0] port = addr[1] site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) - await site.start() + try: + await site.start() + except OSError as e: + if e.errno == errno.EADDRINUSE: + logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.") + raise SystemExit(1) + raise if not hasattr(self, 'address'): self.address = address #TODO: remove this diff --git a/tests-unit/comfy_api_test/multicombo_serialization_test.py b/tests-unit/comfy_api_test/multicombo_serialization_test.py new file mode 100644 index 000000000..421c65a0d --- /dev/null +++ b/tests-unit/comfy_api_test/multicombo_serialization_test.py @@ -0,0 +1,78 @@ +from comfy_api.latest._io import Combo, MultiCombo + + +def test_multicombo_serializes_multi_select_as_object(): + multi_combo = MultiCombo.Input( + id="providers", + options=["a", "b", "c"], + default=["a"], + ) + + serialized = multi_combo.as_dict() + + assert serialized["multiselect"] is True + assert "multi_select" in serialized + assert serialized["multi_select"] == {} + + +def test_multicombo_serializes_multi_select_with_placeholder_and_chip(): + multi_combo = MultiCombo.Input( + id="providers", + options=["a", "b", "c"], + default=["a"], + placeholder="Select providers", + chip=True, + ) + + serialized = multi_combo.as_dict() + + assert serialized["multiselect"] is True + assert serialized["multi_select"] == { + "placeholder": "Select providers", + "chip": True, + } + + +def test_combo_does_not_serialize_multiselect(): + """Regular Combo should not have multiselect in its serialized output.""" + combo = Combo.Input( + id="choice", + options=["a", "b", "c"], + ) + + serialized = combo.as_dict() + + # Combo sets multiselect=False, but prune_dict keeps False (not None), + # so it should be present but False + assert serialized.get("multiselect") is False + assert "multi_select" not in serialized + + +def _validate_combo_values(val, combo_options, is_multiselect): + """Reproduce the validation logic from execution.py for testing.""" + if is_multiselect and isinstance(val, list): + return [v for v in val if v not in combo_options] + else: + return [val] if val not in combo_options else [] + + +def test_multicombo_validation_accepts_valid_list(): + options = ["a", "b", "c"] + assert _validate_combo_values(["a", "b"], options, True) == [] + + +def test_multicombo_validation_rejects_invalid_values(): + options = ["a", "b", "c"] + assert _validate_combo_values(["a", "x"], options, True) == ["x"] + + +def test_multicombo_validation_accepts_empty_list(): + options = ["a", "b", "c"] + assert _validate_combo_values([], options, True) == [] + + +def test_combo_validation_rejects_list_even_with_valid_items(): + """A regular Combo should not accept a list value.""" + options = ["a", "b", "c"] + invalid = _validate_combo_values(["a", "b"], options, False) + assert len(invalid) > 0 diff --git a/tests-unit/deploy_environment_test.py b/tests-unit/deploy_environment_test.py new file mode 100644 index 000000000..c3497fbb0 --- /dev/null +++ b/tests-unit/deploy_environment_test.py @@ -0,0 +1,109 @@ +"""Tests for comfy.deploy_environment.""" + +import os + +import pytest + +from comfy import deploy_environment +from comfy.deploy_environment import get_deploy_environment + + +@pytest.fixture(autouse=True) +def _reset_cache_and_install_dir(tmp_path, monkeypatch): + """Reset the functools cache and point the ComfyUI install dir at a tmp dir for each test.""" + get_deploy_environment.cache_clear() + monkeypatch.setattr(deploy_environment, "_COMFY_INSTALL_DIR", str(tmp_path)) + yield + get_deploy_environment.cache_clear() + + +def _write_env_file(tmp_path, content: str) -> str: + """Write the env file with exact content (no newline translation). + + `newline=""` disables Python's text-mode newline translation so the bytes + on disk match the literal string passed in, regardless of host OS. + Newline-style tests (CRLF, lone CR) rely on this. + """ + path = os.path.join(str(tmp_path), ".comfy_environment") + with open(path, "w", encoding="utf-8", newline="") as f: + f.write(content) + return path + + +class TestGetDeployEnvironment: + def test_returns_local_git_when_file_missing(self): + assert get_deploy_environment() == "local-git" + + def test_reads_value_from_file(self, tmp_path): + _write_env_file(tmp_path, "local-desktop2-standalone\n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_strips_trailing_whitespace_and_newline(self, tmp_path): + _write_env_file(tmp_path, " local-desktop2-standalone \n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_only_first_line_is_used(self, tmp_path): + _write_env_file(tmp_path, "first-line\nsecond-line\n") + assert get_deploy_environment() == "first-line" + + def test_crlf_line_ending(self, tmp_path): + # Windows editors often save text files with CRLF line endings. + # The CR must not end up in the returned value. + _write_env_file(tmp_path, "local-desktop2-standalone\r\n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_crlf_multiline_only_first_line_used(self, tmp_path): + _write_env_file(tmp_path, "first-line\r\nsecond-line\r\n") + assert get_deploy_environment() == "first-line" + + def test_crlf_with_surrounding_whitespace(self, tmp_path): + _write_env_file(tmp_path, " local-desktop2-standalone \r\n") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_lone_cr_line_ending(self, tmp_path): + # Classic-Mac / some legacy editors use a bare CR. + # Universal-newlines decoding treats it as a line terminator too. + _write_env_file(tmp_path, "local-desktop2-standalone\r") + assert get_deploy_environment() == "local-desktop2-standalone" + + def test_empty_file_falls_back_to_default(self, tmp_path): + _write_env_file(tmp_path, "") + assert get_deploy_environment() == "local-git" + + def test_empty_after_whitespace_strip_falls_back_to_default(self, tmp_path): + _write_env_file(tmp_path, " \n") + assert get_deploy_environment() == "local-git" + + def test_strips_control_chars_within_first_line(self, tmp_path): + # Embedded NUL/control chars in the value should be stripped + # (header-injection / smuggling protection). + _write_env_file(tmp_path, "abc\x00\x07xyz\n") + assert get_deploy_environment() == "abcxyz" + + def test_strips_non_ascii_characters(self, tmp_path): + _write_env_file(tmp_path, "café-é\n") + assert get_deploy_environment() == "caf-" + + def test_caps_read_at_128_bytes(self, tmp_path): + # A single huge line with no newline must not be fully read into memory. + huge = "x" * 10_000 + _write_env_file(tmp_path, huge) + result = get_deploy_environment() + assert result == "x" * 128 + + def test_result_is_cached_across_calls(self, tmp_path): + path = _write_env_file(tmp_path, "first_value\n") + assert get_deploy_environment() == "first_value" + # Overwrite the file — cached value should still be returned. + with open(path, "w", encoding="utf-8") as f: + f.write("second_value\n") + assert get_deploy_environment() == "first_value" + + def test_unreadable_file_falls_back_to_default(self, tmp_path, monkeypatch): + _write_env_file(tmp_path, "should_not_be_used\n") + + def _boom(*args, **kwargs): + raise OSError("simulated read failure") + + monkeypatch.setattr("builtins.open", _boom) + assert get_deploy_environment() == "local-git" diff --git a/tests-unit/feature_flags_test.py b/tests-unit/feature_flags_test.py index f2702cfc8..8ec52a124 100644 --- a/tests-unit/feature_flags_test.py +++ b/tests-unit/feature_flags_test.py @@ -1,10 +1,15 @@ """Tests for feature flags functionality.""" +import pytest + from comfy_api.feature_flags import ( get_connection_feature, supports_feature, get_server_features, + CLI_FEATURE_FLAG_REGISTRY, SERVER_FEATURE_FLAGS, + _coerce_flag_value, + _parse_cli_feature_flags, ) @@ -96,3 +101,83 @@ class TestFeatureFlags: result = get_connection_feature(sockets_metadata, "sid1", "any_feature") assert result is False assert supports_feature(sockets_metadata, "sid1", "any_feature") is False + + +class TestCoerceFlagValue: + """Test suite for _coerce_flag_value.""" + + def test_registered_bool_true(self): + assert _coerce_flag_value("show_signin_button", "true") is True + assert _coerce_flag_value("show_signin_button", "True") is True + + def test_registered_bool_false(self): + assert _coerce_flag_value("show_signin_button", "false") is False + assert _coerce_flag_value("show_signin_button", "FALSE") is False + + def test_unregistered_key_stays_string(self): + assert _coerce_flag_value("unknown_flag", "true") == "true" + assert _coerce_flag_value("unknown_flag", "42") == "42" + + def test_bool_typo_raises(self): + """Strict bool: typos like 'ture' or 'yes' must raise so the flag can be dropped.""" + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "ture") + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "yes") + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "1") + with pytest.raises(ValueError): + _coerce_flag_value("show_signin_button", "") + + def test_failed_int_coercion_raises(self, monkeypatch): + """Malformed values for typed flags must raise; caller decides what to do.""" + monkeypatch.setitem( + CLI_FEATURE_FLAG_REGISTRY, + "test_int_flag", + {"type": "int", "default": 0, "description": "test"}, + ) + with pytest.raises(ValueError): + _coerce_flag_value("test_int_flag", "not_a_number") + + +class TestParseCliFeatureFlags: + """Test suite for _parse_cli_feature_flags.""" + + def test_single_flag(self, monkeypatch): + monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button=true"]})()) + result = _parse_cli_feature_flags() + assert result == {"show_signin_button": True} + + def test_missing_equals_defaults_to_true(self, monkeypatch): + """Bare flag without '=' is treated as the string 'true' (and coerced if registered).""" + monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button", "valid=1"]})()) + result = _parse_cli_feature_flags() + assert result == {"show_signin_button": True, "valid": "1"} + + def test_empty_key_skipped(self, monkeypatch): + monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["=value", "valid=1"]})()) + result = _parse_cli_feature_flags() + assert result == {"valid": "1"} + + def test_invalid_bool_value_dropped(self, monkeypatch, caplog): + """A typo'd bool value must be dropped entirely, not silently set to False + and not stored as a raw string. A warning must be logged.""" + monkeypatch.setattr( + "comfy_api.feature_flags.args", + type("Args", (), {"feature_flag": ["show_signin_button=ture", "valid=1"]})(), + ) + with caplog.at_level("WARNING"): + result = _parse_cli_feature_flags() + assert result == {"valid": "1"} + assert "show_signin_button" not in result + assert any("show_signin_button" in r.message and "drop" in r.message.lower() for r in caplog.records) + + +class TestCliFeatureFlagRegistry: + """Test suite for the CLI feature flag registry.""" + + def test_registry_entries_have_required_fields(self): + for key, info in CLI_FEATURE_FLAG_REGISTRY.items(): + assert "type" in info, f"{key} missing 'type'" + assert "default" in info, f"{key} missing 'default'" + assert "description" in info, f"{key} missing 'description'"