diff --git a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat
similarity index 66%
rename from .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat
rename to .ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat
index cece0aeb2..94ad31942 100755
--- a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat
+++ b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat
@@ -1,2 +1,2 @@
-.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram
pause
diff --git a/.github/workflows/openapi-lint.yml b/.github/workflows/openapi-lint.yml
new file mode 100644
index 000000000..be949de2a
--- /dev/null
+++ b/.github/workflows/openapi-lint.yml
@@ -0,0 +1,31 @@
+name: OpenAPI Lint
+
+on:
+ pull_request:
+ paths:
+ - 'openapi.yaml'
+ - '.spectral.yaml'
+ - '.github/workflows/openapi-lint.yml'
+
+permissions:
+ contents: read
+
+jobs:
+ spectral:
+ name: Run Spectral
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Node.js
+ uses: actions/setup-node@v4
+ with:
+ node-version: '20'
+
+ - name: Install Spectral
+ run: npm install -g @stoplight/spectral-cli@6
+
+ - name: Lint openapi.yaml
+ run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error
diff --git a/.gitignore b/.gitignore
index 0ab4ba75e..fc426eda4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,4 @@ web_custom_versions/
.DS_Store
filtered-openapi.yaml
uv.lock
+.comfy_environment
diff --git a/.spectral.yaml b/.spectral.yaml
new file mode 100644
index 000000000..4bb4a4a94
--- /dev/null
+++ b/.spectral.yaml
@@ -0,0 +1,91 @@
+extends:
+ - spectral:oas
+
+# Severity levels: error, warn, info, hint, off
+# Rules from the built-in "spectral:oas" ruleset are active by default.
+# Below we tune severity and add custom rules for our conventions.
+#
+# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the
+# organization are linted against a single consistent standard.
+
+rules:
+ # -----------------------------------------------------------------------
+ # Built-in rule severity overrides
+ # -----------------------------------------------------------------------
+ operation-operationId: error
+ operation-description: warn
+ operation-tag-defined: error
+ info-contact: off
+ info-description: warn
+ no-eval-in-markdown: error
+ no-$ref-siblings: error
+
+ # -----------------------------------------------------------------------
+ # Custom rules: naming conventions
+ # -----------------------------------------------------------------------
+
+ # Property names should be snake_case
+ property-name-snake-case:
+ description: Property names must be snake_case
+ severity: warn
+ given: "$.components.schemas.*.properties[*]~"
+ then:
+ function: pattern
+ functionOptions:
+ match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$"
+
+ # Operation IDs should be camelCase
+ operation-id-camel-case:
+ description: Operation IDs must be camelCase
+ severity: warn
+ given: "$.paths.*.*.operationId"
+ then:
+ function: pattern
+ functionOptions:
+ match: "^[a-z][a-zA-Z0-9]*$"
+
+ # -----------------------------------------------------------------------
+ # Custom rules: response conventions
+ # -----------------------------------------------------------------------
+
+ # Error responses (4xx, 5xx) should use a consistent shape
+ error-response-schema:
+ description: Error responses should reference a standard error schema
+ severity: hint
+ given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema"
+ then:
+ field: "$ref"
+ function: truthy
+
+ # All 2xx responses with JSON body should have a schema
+ response-schema-defined:
+ description: Success responses with JSON content should define a schema
+ severity: warn
+ given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']"
+ then:
+ field: schema
+ function: truthy
+
+ # -----------------------------------------------------------------------
+ # Custom rules: best practices
+ # -----------------------------------------------------------------------
+
+ # Path parameters must have a description
+ path-param-description:
+ description: Path parameters should have a description
+ severity: warn
+ given:
+ - "$.paths.*.parameters[?(@.in == 'path')]"
+ - "$.paths.*.*.parameters[?(@.in == 'path')]"
+ then:
+ field: description
+ function: truthy
+
+ # Schemas should have a description
+ schema-description:
+ description: Component schemas should have a description
+ severity: hint
+ given: "$.components.schemas.*"
+ then:
+ field: description
+ function: truthy
diff --git a/CODEOWNERS b/CODEOWNERS
index 171e16855..ba8690509 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,5 +1,5 @@
# Admins
-* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128
+* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
# Partner nodes team maintains API nodes
/comfy_api_nodes/ @Comfy-Org/partner-nodes-mergers
diff --git a/README.md b/README.md
index f05311421..0fd317d0a 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# ComfyUI
-**The most powerful and modular visual AI engine and application.**
+**The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url]
@@ -31,10 +31,16 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
-
+
+
-ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
+ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more.
+- ComfyUI natively supports the latest open-source state of the art models.
+- API nodes provide access to the best closed-source models such as Nano Banana, Seedance, and Hunyuan3D.
+- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
+- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
+- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started
@@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
+ - Ernie Image
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -126,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- - Releases a new stable version (e.g., v0.7.0) roughly every week.
+ - Releases a new stable minor version (e.g., v0.7.0) roughly every 2 weeks.
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
- Minor versions will be used for releases off the master branch.
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
@@ -193,13 +200,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
-#### Alternative Downloads:
+#### All Official Portable Downloads:
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
-[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
+[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
-[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
+[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (Supports Nvidia 20 series and newer GPUs).
+
+[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
diff --git a/app/user_manager.py b/app/user_manager.py
index e18afb71b..0517b3344 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -28,8 +28,8 @@ def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
- "modified": os.path.getmtime(path),
- "created": os.path.getctime(path)
+ "modified": int(os.path.getmtime(path) * 1000),
+ "created": int(os.path.getctime(path) * 1000),
}
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index dbaadf723..9dadb0093 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -90,8 +90,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
-parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
+parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -238,6 +238,8 @@ database_default_path = os.path.abspath(
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
+parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
+parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
if comfy.options.args_parsing:
args = parser.parse_args()
diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index cb44ee6e8..db57537a2 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
- idx = tuple([slice(None)] * dim + [self.index_list])
+ indices = self.index_list
+ anchor_idx = getattr(self, 'causal_anchor_index', None)
+ if anchor_idx is not None and anchor_idx >= 0:
+ indices = [anchor_idx] + list(indices)
+ idx = tuple([slice(None)] * dim + [indices])
window = full[idx]
if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list])
@@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0:
- indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
+ anchor_idx = getattr(window, 'causal_anchor_index', None)
+ if anchor_idx is not None and anchor_idx >= 0:
+ # anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
+ skip_count = temporal_offset - 1
+ else:
+ skip_count = temporal_offset
+
+ indices = [i - temporal_offset for i in window.index_list[skip_count:]]
indices = [i for i in indices if 0 <= i]
else:
indices = list(window.index_list)
@@ -150,7 +161,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
- closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
+ closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
+ causal_window_fix: bool=True):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
+ self.causal_window_fix = causal_window_fix
self.callbacks = {}
@@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
+ # causal_window_fix: prepend a pre-window frame that will be stripped post-forward
+ anchor_applied = False
+ if self.causal_window_fix:
+ anchor_idx = window.index_list[0] - 1
+ if 0 <= anchor_idx < x_in.size(self.dim):
+ window.causal_anchor_index = anchor_idx
+ anchor_applied = True
+
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
@@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
+
+ # strip causal_window_fix anchor if applied
+ if anchor_applied:
+ for i in range(len(sub_conds_out)):
+ sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
+
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
return results
diff --git a/comfy/deploy_environment.py b/comfy/deploy_environment.py
new file mode 100644
index 000000000..8c99a3584
--- /dev/null
+++ b/comfy/deploy_environment.py
@@ -0,0 +1,34 @@
+import functools
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+_DEFAULT_DEPLOY_ENV = "local-git"
+_ENV_FILENAME = ".comfy_environment"
+
+# Resolve the ComfyUI install directory (the parent of this `comfy/` package).
+# We deliberately avoid `folder_paths.base_path` here because that is overridden
+# by the `--base-directory` CLI arg to a user-supplied path, whereas the
+# `.comfy_environment` marker is written by launchers/installers next to the
+# ComfyUI install itself.
+_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+
+@functools.cache
+def get_deploy_environment() -> str:
+ env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME)
+ try:
+ with open(env_file, encoding="utf-8") as f:
+ # Cap the read so a malformed or maliciously crafted file (e.g.
+ # a single huge line with no newline) can't blow up memory.
+ first_line = f.readline(128).strip()
+ value = "".join(c for c in first_line if 32 <= ord(c) < 127)
+ if value:
+ return value
+ except FileNotFoundError:
+ pass
+ except Exception as e:
+ logger.error("Failed to read %s: %s", env_file, e)
+
+ return _DEFAULT_DEPLOY_ENV
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 6978eb717..d33bc7199 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
+
+
+@torch.no_grad()
+def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
+ num_frame_per_block=1):
+ """
+ Autoregressive video sampler: block-by-block denoising with KV cache
+ and flow-match re-noising for Causal Forcing / Self-Forcing models.
+
+ Requires a Causal-WAN compatible model (diffusion_model must expose
+ init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
+
+ All AR-loop parameters are passed via the SamplerARVideo node, not read
+ from the checkpoint or transformer_options.
+ """
+ extra_args = {} if extra_args is None else extra_args
+ model_options = extra_args.get("model_options", {})
+ transformer_options = model_options.get("transformer_options", {})
+
+ if x.ndim != 5:
+ raise ValueError(
+ f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
+ "This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
+ )
+
+ inner_model = model.inner_model.inner_model
+ causal_model = inner_model.diffusion_model
+
+ if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
+ raise TypeError(
+ "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
+ "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
+ "does not support this interface — choose a different sampler."
+ )
+
+ seed = extra_args.get("seed", 0)
+
+ bs, c, lat_t, lat_h, lat_w = x.shape
+ frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
+ num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
+ device = x.device
+ model_dtype = inner_model.get_dtype()
+
+ kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
+ crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
+
+ output = torch.zeros_like(x)
+ s_in = x.new_ones([x.shape[0]])
+ current_start_frame = 0
+ num_sigma_steps = len(sigmas) - 1
+ total_real_steps = num_blocks * num_sigma_steps
+ step_count = 0
+
+ try:
+ for block_idx in trange(num_blocks, disable=disable):
+ bf = min(num_frame_per_block, lat_t - current_start_frame)
+ fs, fe = current_start_frame, current_start_frame + bf
+ noisy_input = x[:, :, fs:fe]
+
+ ar_state = {
+ "start_frame": current_start_frame,
+ "kv_caches": kv_caches,
+ "crossattn_caches": crossattn_caches,
+ }
+ transformer_options["ar_state"] = ar_state
+
+ for i in range(num_sigma_steps):
+ denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
+
+ if callback is not None:
+ scaled_i = step_count * num_sigma_steps // total_real_steps
+ callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
+ "sigma_hat": sigmas[i], "denoised": denoised})
+
+ if sigmas[i + 1] == 0:
+ noisy_input = denoised
+ else:
+ sigma_next = sigmas[i + 1]
+ torch.manual_seed(seed + block_idx * 1000 + i)
+ fresh_noise = torch.randn_like(denoised)
+ noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
+
+ for cache in kv_caches:
+ cache["end"] -= bf * frame_seq_len
+
+ step_count += 1
+
+ output[:, :, fs:fe] = noisy_input
+
+ for cache in kv_caches:
+ cache["end"] -= bf * frame_seq_len
+ zero_sigma = sigmas.new_zeros([1])
+ _ = model(noisy_input, zero_sigma * s_in, **extra_args)
+
+ current_start_frame += bf
+ finally:
+ transformer_options.pop("ar_state", None)
+
+ return output
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 3dac5be18..91bebed3d 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -9,6 +9,7 @@ class LatentFormat:
latent_rgb_factors_reshape = None
taesd_decoder_name = None
spacial_downscale_ratio = 8
+ temporal_downscale_ratio = 1
def process_in(self, latent):
return latent * self.scale_factor
@@ -235,6 +236,7 @@ class Flux2(LatentFormat):
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
+ temporal_downscale_ratio = 6
def __init__(self):
self.scale_factor = 1.0
@@ -278,6 +280,7 @@ class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
spacial_downscale_ratio = 32
+ temporal_downscale_ratio = 8
def __init__(self):
self.latent_rgb_factors = [
@@ -421,6 +424,7 @@ class LTXAV(LTXV):
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
+ temporal_downscale_ratio = 4
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
@@ -447,6 +451,7 @@ class HunyuanVideo(LatentFormat):
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3
+ temporal_downscale_ratio = 8
latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423],
@@ -472,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat):
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
+ temporal_downscale_ratio = 4
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
@@ -734,6 +740,7 @@ class HunyuanVideo15(LatentFormat):
latent_channels = 32
latent_dimensions = 3
spacial_downscale_ratio = 16
+ temporal_downscale_ratio = 4
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"
@@ -786,8 +793,27 @@ class ZImagePixelSpace(ChromaRadiance):
pass
class CogVideoX(LatentFormat):
+ """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
+
+ scale_factor matches the vae/config.json scaling_factor for the 2b variant.
+ The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*)
+ use a different value; see CogVideoX1_5 below.
+ """
latent_channels = 16
latent_dimensions = 3
+ temporal_downscale_ratio = 4
def __init__(self):
self.scale_factor = 1.15258426
+
+
+class CogVideoX1_5(CogVideoX):
+ """Latent format for 5b-class CogVideoX checkpoints.
+
+ Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun
+ V1.5-5b family (including VOID inpainting). All of these have
+ scaling_factor=0.7 in their vae/config.json. Auto-selected in
+ supported_models.CogVideoX_T2V based on transformer hidden dim.
+ """
+ def __init__(self):
+ self.scale_factor = 0.7
diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index 6f2ba41ef..3fb87b4a3 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
+import comfy.model_prefetch
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
+ prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
+ comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
+ comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
+
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index b193fe5e8..a68cb8439 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
+TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
+
if model_management.xformers_enabled():
import xformers
import xformers.ops
@@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
- scale = dim_head ** -0.5
+ if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
+ n_rep = q.shape[-3] // k.shape[-3]
+ k = k.repeat_interleave(n_rep, dim=-3)
+ v = v.repeat_interleave(n_rep, dim=-3)
+
+ scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
@@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
+ if "scale" in kwargs:
+ # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
+ query = query * (kwargs["scale"] * dim_head ** 0.5)
+
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
- scale = dim_head ** -0.5
+ scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
@@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
+ # Pass through extra SDPA kwargs (scale, enable_gqa) if provided
+ # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
+ sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
+ sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
+
if SDP_BATCH_LIMIT >= b:
- out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
- dropout_p=0.0, is_causal=False
+ dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py
new file mode 100644
index 000000000..d72f53602
--- /dev/null
+++ b/comfy/ldm/wan/ar_model.py
@@ -0,0 +1,276 @@
+"""
+CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
+autoregressive (frame-by-frame) video generation via Causal Forcing.
+
+Weight-compatible with the standard WanModel -- same layer names, same shapes.
+The difference is purely in the forward pass: this model processes one temporal
+block at a time and maintains a KV cache across blocks.
+
+Reference: https://github.com/thu-ml/Causal-Forcing
+"""
+
+import torch
+import torch.nn as nn
+
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.wan.model import (
+ sinusoidal_embedding_1d,
+ repeat_e,
+ WanModel,
+ WanAttentionBlock,
+)
+import comfy.ldm.common_dit
+import comfy.model_management
+
+
+class CausalWanSelfAttention(nn.Module):
+ """Self-attention with KV cache support for autoregressive inference."""
+
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
+ eps=1e-6, operation_settings={}):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ ops = operation_settings.get("operations")
+ device = operation_settings.get("device")
+ dtype = operation_settings.get("dtype")
+
+ self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
+ self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
+
+ def forward(self, x, freqs, kv_cache=None, transformer_options={}):
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
+ k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
+ v = self.v(x).view(b, s, n, d)
+
+ if kv_cache is None:
+ x = optimized_attention(
+ q.view(b, s, n * d),
+ k.view(b, s, n * d),
+ v.view(b, s, n * d),
+ heads=self.num_heads,
+ transformer_options=transformer_options,
+ )
+ else:
+ end = kv_cache["end"]
+ new_end = end + s
+
+ # Roped K and plain V go into cache
+ kv_cache["k"][:, end:new_end] = k
+ kv_cache["v"][:, end:new_end] = v
+ kv_cache["end"] = new_end
+
+ x = optimized_attention(
+ q.view(b, s, n * d),
+ kv_cache["k"][:, :new_end].view(b, new_end, n * d),
+ kv_cache["v"][:, :new_end].view(b, new_end, n * d),
+ heads=self.num_heads,
+ transformer_options=transformer_options,
+ )
+
+ x = self.o(x)
+ return x
+
+
+class CausalWanAttentionBlock(WanAttentionBlock):
+ """Transformer block with KV-cached self-attention and cross-attention caching."""
+
+ def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
+ window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
+ eps=1e-6, operation_settings={}):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps,
+ operation_settings=operation_settings)
+ self.self_attn = CausalWanSelfAttention(
+ dim, num_heads, window_size, qk_norm, eps,
+ operation_settings=operation_settings)
+
+ def forward(self, x, e, freqs, context, context_img_len=257,
+ kv_cache=None, crossattn_cache=None, transformer_options={}):
+ if e.ndim < 4:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+ else:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+
+ # Self-attention with optional KV cache
+ x = x.contiguous()
+ y = self.self_attn(
+ torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+ freqs, kv_cache=kv_cache, transformer_options=transformer_options)
+ x = torch.addcmul(x, y, repeat_e(e[2], x))
+ del y
+
+ # Cross-attention with optional caching
+ if crossattn_cache is not None and crossattn_cache.get("is_init"):
+ q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
+ x_ca = optimized_attention(
+ q, crossattn_cache["k"], crossattn_cache["v"],
+ heads=self.num_heads, transformer_options=transformer_options)
+ x = x + self.cross_attn.o(x_ca)
+ else:
+ x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+ if crossattn_cache is not None:
+ crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
+ crossattn_cache["v"] = self.cross_attn.v(context)
+ crossattn_cache["is_init"] = True
+
+ # FFN
+ y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+ x = torch.addcmul(x, y, repeat_e(e[5], x))
+ return x
+
+
+class CausalWanModel(WanModel):
+ """
+ Wan 2.1 diffusion backbone with causal KV-cache support.
+
+ Same weight structure as WanModel -- loads identical state dicts.
+ Adds forward_block() for frame-by-frame autoregressive inference.
+ """
+
+ def __init__(self,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None):
+ super().__init__(
+ model_type=model_type, patch_size=patch_size, text_len=text_len,
+ in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
+ text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
+ num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
+ cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
+ wan_attn_block_class=CausalWanAttentionBlock,
+ device=device, dtype=dtype, operations=operations)
+
+ def forward_block(self, x, timestep, context, start_frame,
+ kv_caches, crossattn_caches, clip_fea=None):
+ """
+ Forward one temporal block for autoregressive inference.
+
+ Args:
+ x: [B, C, block_frames, H, W] input latent for the current block
+ timestep: [B, block_frames] per-frame timesteps
+ context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
+ start_frame: temporal frame index for RoPE offset
+ kv_caches: list of per-layer KV cache dicts
+ crossattn_caches: list of per-layer cross-attention cache dicts
+ clip_fea: optional CLIP features for I2V
+
+ Returns:
+ flow_pred: [B, C_out, block_frames, H, W] flow prediction
+ """
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+ bs, c, t, h, w = x.shape
+
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # Per-frame time embedding
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
+ e = e.reshape(timestep.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+ # Text embedding (reuses crossattn_cache after first block)
+ context = self.text_embedding(context)
+
+ context_img_len = None
+ if clip_fea is not None and self.img_emb is not None:
+ context_clip = self.img_emb(clip_fea)
+ context = torch.concat([context_clip, context], dim=1)
+ context_img_len = clip_fea.shape[-2]
+
+ # RoPE for current block's temporal position
+ freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
+
+ # Transformer blocks
+ for i, block in enumerate(self.blocks):
+ x = block(x, e=e0, freqs=freqs, context=context,
+ context_img_len=context_img_len,
+ kv_cache=kv_caches[i],
+ crossattn_cache=crossattn_caches[i])
+
+ # Head
+ x = self.head(x, e)
+
+ # Unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x[:, :, :t, :h, :w]
+
+ def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
+ """Create fresh KV caches for all layers."""
+ caches = []
+ for _ in range(self.num_layers):
+ caches.append({
+ "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
+ "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
+ "end": 0,
+ })
+ return caches
+
+ def init_crossattn_caches(self, batch_size, device, dtype):
+ """Create fresh cross-attention caches for all layers."""
+ caches = []
+ for _ in range(self.num_layers):
+ caches.append({"is_init": False})
+ return caches
+
+ def reset_kv_caches(self, kv_caches):
+ """Reset KV caches to empty (reuse allocated memory)."""
+ for cache in kv_caches:
+ cache["end"] = 0
+
+ def reset_crossattn_caches(self, crossattn_caches):
+ """Reset cross-attention caches."""
+ for cache in crossattn_caches:
+ cache["is_init"] = False
+
+ @property
+ def head_dim(self):
+ return self.dim // self.num_heads
+
+ def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+ ar_state = transformer_options.get("ar_state")
+ if ar_state is not None:
+ bs = x.shape[0]
+ block_frames = x.shape[2]
+ t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
+ return self.forward_block(
+ x=x, timestep=t_per_frame, context=context,
+ start_frame=ar_state["start_frame"],
+ kv_caches=ar_state["kv_caches"],
+ crossattn_caches=ar_state["crossattn_caches"],
+ clip_fea=clip_fea,
+ )
+
+ return super().forward(x, timestep, context, clip_fea=clip_fea,
+ time_dim_concat=time_dim_concat,
+ transformer_options=transformer_options, **kwargs)
diff --git a/comfy/lora.py b/comfy/lora.py
index e4337c729..db8f16bcb 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -17,6 +17,7 @@
"""
from __future__ import annotations
+import comfy.memory_management
import comfy.utils
import comfy.model_management
import comfy.model_base
@@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight
return weight
+
+def prefetch_prepared_value(value, allocate_buffer, stream):
+ if isinstance(value, torch.Tensor):
+ dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
+ comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
+ return comfy.memory_management.interpret_gathered_like([value], dest)[0]
+ elif isinstance(value, weight_adapter.WeightAdapterBase):
+ return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
+ elif isinstance(value, tuple):
+ return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
+ elif isinstance(value, list):
+ return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
+
+ return value
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 50dab5782..57a1e44d2 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
+import comfy.ldm.wan.ar_model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
@@ -214,6 +215,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
+ transformer_options = transformer_options.copy()
+ transformer_options["prefetch_dynamic_vbars"] = (
+ self.current_patcher is not None and self.current_patcher.is_dynamic()
+ )
+
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
@@ -1360,6 +1366,13 @@ class WAN21(BaseModel):
return out
+class WAN21_CausalAR(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device,
+ unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
+ self.image_to_video = False
+
+
class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 95af40012..21738a4c7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
+import comfy_aimdo.vram_buffer
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -112,10 +113,6 @@ if args.directml is not None:
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
-try:
- import intel_extension_for_pytorch as ipex # noqa: F401
-except:
- pass
try:
_ = torch.xpu.device_count()
@@ -583,9 +580,6 @@ class LoadedModel:
real_model = self.model.model
- if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
- with torch.no_grad():
- real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
@@ -727,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
- models_temp = set()
+ # Order-preserving dedup. A plain set() would randomize iteration order across runs.
+ models_temp = {}
for m in models:
- models_temp.add(m)
+ models_temp[m] = None
for mm in m.model_patches_models():
- models_temp.add(mm)
+ models_temp[mm] = None
- models = models_temp
+ models = list(models_temp)
+ models.reverse()
models_to_load = []
@@ -1182,6 +1178,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
+STREAM_AIMDO_CAST_BUFFERS = {}
+LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+
+DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@@ -1215,13 +1215,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer
+def get_aimdo_cast_buffer(offload_stream, device):
+ cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
+ if cast_buffer is None:
+ cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
+ STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
+
+ return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
+ global LARGEST_AIMDO_CASTED_WEIGHT
+
LARGEST_CASTED_WEIGHT = (None, 0)
- for offload_stream in STREAM_CAST_BUFFERS:
- offload_stream.synchronize()
+ LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+ for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
+ if offload_stream is not None:
+ offload_stream.synchronize()
synchronize()
+
STREAM_CAST_BUFFERS.clear()
+ STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
@@ -1581,10 +1594,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
- if torch_version_numeric < (2, 3):
- return True
- else:
- return torch.xpu.get_device_properties(device).has_fp16
+ return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu():
return True
@@ -1650,10 +1660,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
- if torch_version_numeric < (2, 3):
- return True
- else:
- return torch.xpu.is_bf16_supported()
+ return torch.xpu.is_bf16_supported()
if is_ascend_npu():
return True
@@ -1784,6 +1791,7 @@ def soft_empty_cache(force=False):
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
elif is_intel_xpu():
+ torch.xpu.synchronize()
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index e259aed63..7d2d6883f 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -121,9 +121,20 @@ class LowVramPatch:
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
+ self.prepared_patches = None
+
+ def prepare(self, allocate_buffer, stream):
+ self.prepared_patches = [
+ (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
+ for patch in self.patches[self.key]
+ ]
+
+ def clear_prepared(self):
+ self.prepared_patches = None
def __call__(self, weight):
- return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
+ patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
+ return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py
new file mode 100644
index 000000000..72e11dec6
--- /dev/null
+++ b/comfy/model_prefetch.py
@@ -0,0 +1,66 @@
+import comfy_aimdo.model_vbar
+import comfy.model_management
+import comfy.ops
+
+PREFETCH_QUEUES = []
+
+def cleanup_prefetched_modules(comfy_modules):
+ for s in comfy_modules:
+ prefetch = getattr(s, "_prefetch", None)
+ if prefetch is None:
+ continue
+ for param_key in ("weight", "bias"):
+ lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+ if lowvram_fn is not None:
+ lowvram_fn.clear_prepared()
+ if prefetch["signature"] is not None:
+ comfy_aimdo.model_vbar.vbar_unpin(s._v)
+ delattr(s, "_prefetch")
+
+def cleanup_prefetch_queues():
+ global PREFETCH_QUEUES
+
+ for queue in PREFETCH_QUEUES:
+ for entry in queue:
+ if entry is None or not isinstance(entry, tuple):
+ continue
+ _, prefetch_state = entry
+ comfy_modules = prefetch_state[1]
+ if comfy_modules is not None:
+ cleanup_prefetched_modules(comfy_modules)
+ PREFETCH_QUEUES = []
+
+def prefetch_queue_pop(queue, device, module):
+ if queue is None:
+ return
+
+ consumed = queue.pop(0)
+ if consumed is not None:
+ offload_stream, prefetch_state = consumed
+ if offload_stream is not None:
+ offload_stream.wait_stream(comfy.model_management.current_stream(device))
+ _, comfy_modules = prefetch_state
+ if comfy_modules is not None:
+ cleanup_prefetched_modules(comfy_modules)
+
+ prefetch = queue[0]
+ if prefetch is not None:
+ comfy_modules = []
+ for s in prefetch.modules():
+ if hasattr(s, "_v"):
+ comfy_modules.append(s)
+
+ offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
+ comfy.model_management.sync_stream(device, offload_stream)
+ queue[0] = (offload_stream, (prefetch, comfy_modules))
+
+def make_prefetch_queue(queue, device, transformer_options):
+ if (not transformer_options.get("prefetch_dynamic_vbars", False)
+ or comfy.model_management.NUM_STREAMS == 0
+ or comfy.model_management.is_device_cpu(device)
+ or not comfy.model_management.device_supports_non_blocking(device)):
+ return None
+
+ queue = [None] + queue + [None]
+ PREFETCH_QUEUES.append(queue)
+ return queue
diff --git a/comfy/ops.py b/comfy/ops.py
index 050f7cda0..585c185a3 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
- #vbar doesn't support CPU weights, but some custom nodes have weird paths
- #that might switch the layer to the CPU and expect it to work. We have to take
- #a clone conservatively as we are mmapped and some SFT files are packed misaligned
- #If you are a custom node author reading this, please move your layer to the GPU
- #or declare your ModelPatcher as CPU in the first place.
- if comfy.model_management.is_device_cpu(device):
- materialize_meta_param(s, ["weight", "bias"])
- weight = s.weight.to(dtype=dtype, copy=True)
- if isinstance(weight, QuantizedTensor):
- weight = weight.dequantize()
- bias = None
- if s.bias is not None:
- bias = s.bias.to(dtype=bias_dtype, copy=True)
- return weight, bias, (None, None, None)
-
+# FIXME: add n=1 cache hit fast path
+def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None
- xfer_dest = None
+ cast_buffer = None
+ cast_buffer_offset = 0
+
+ def ensure_offload_stream(module, required_size, check_largest):
+ nonlocal offload_stream
+ nonlocal cast_buffer
+
+ if offload_stream is None:
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ if offload_stream is None or not check_largest or len(comfy_modules) != 1:
+ return
+
+ current_size = 0 if cast_buffer is None else cast_buffer.size()
+ if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ cast_buffer = None
+ if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
+ comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
+
+ def get_cast_buffer(buffer_size):
+ nonlocal offload_stream
+ nonlocal cast_buffer
+ nonlocal cast_buffer_offset
+
+ if buffer_size == 0:
+ return None
+
+ if offload_stream is None:
+ return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
+
+ cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
+ buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
+ cast_buffer_offset += buffer_size
+ return buffer
+
+ for s in comfy_modules:
+ signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
+ resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+ prefetch = {
+ "signature": signature,
+ "resident": resident,
+ }
- signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
- resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
- if signature is not None:
if resident:
- weight = s._v_weight
- bias = s._v_bias
- else:
- xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
+ s._prefetch = prefetch
+ continue
- if not resident:
materialize_meta_param(s, ["weight", "bias"])
+ xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
+ needs_cast = False
xfer_source = [ s.weight, s.bias ]
@@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None:
continue
if data.dtype != geometry.dtype:
+ needs_cast = True
cast_dest = xfer_dest
- if cast_dest is None:
- cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
- offload_stream = comfy.model_management.get_offload_stream(device)
- if xfer_dest is None and offload_stream is not None:
- xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
- if xfer_dest is None:
- offload_stream = comfy.model_management.get_offload_stream(device)
- xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
+ ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None:
- xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
- offload_stream = None
+ xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
- comfy.model_management.sync_stream(device, offload_stream)
- if cast_dest is not None:
+ for param_key in ("weight", "bias"):
+ lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+ if lowvram_fn is not None:
+ ensure_offload_stream(s, cast_buffer_offset, False)
+ lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
+
+ prefetch["xfer_dest"] = xfer_dest
+ prefetch["cast_dest"] = cast_dest
+ prefetch["cast_geometry"] = cast_geometry
+ prefetch["needs_cast"] = needs_cast
+ s._prefetch = prefetch
+
+ return offload_stream
+
+
+def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
+
+ prefetch = getattr(s, "_prefetch", None)
+
+ if prefetch["resident"]:
+ weight = s._v_weight
+ bias = s._v_bias
+ else:
+ xfer_dest = prefetch["xfer_dest"]
+ if prefetch["needs_cast"]:
+ cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
- comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
+ comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
- params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+ params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0]
bias = params[1]
- if signature is not None:
+ if prefetch["signature"] is not None:
s._v_weight = weight
s._v_bias = bias
- s._v_signature=signature
+ s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
+ if x is None:
+ return None
+
orig = x
def to_dequant(tensor, dtype):
@@ -205,14 +248,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x)
return x
- update_weight = signature is not None
+ update_weight = prefetch["signature"] is not None
+ weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
+ if bias is not None:
+ bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
- weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
- if s.bias is not None:
- bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
+ if prefetch["signature"] is not None:
+ prefetch["resident"] = True
- #FIXME: weird offload return protocol
- return weight, bias, (offload_stream, device if signature is not None else None, None)
+ return weight, bias
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@@ -230,10 +274,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
+ def format_return(result, offloadable):
+ weight, bias, offload_stream = result
+ return (weight, bias, offload_stream) if offloadable else (weight, bias)
+
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
- return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
+
+ #vbar doesn't support CPU weights, but some custom nodes have weird paths
+ #that might switch the layer to the CPU and expect it to work. We have to take
+ #a clone conservatively as we are mmapped and some SFT files are packed misaligned
+ #If you are a custom node author reading this, please move your layer to the GPU
+ #or declare your ModelPatcher as CPU in the first place.
+ if comfy.model_management.is_device_cpu(device):
+ materialize_meta_param(s, ["weight", "bias"])
+ weight = s.weight.to(dtype=dtype, copy=True)
+ if isinstance(weight, QuantizedTensor):
+ weight = weight.dequantize()
+ bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
+ return format_return((weight, bias, (None, None, None)), offloadable)
+
+ prefetched = hasattr(s, "_prefetch")
+ offload_stream = None
+ offload_device = None
+ if not prefetched:
+ offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
+ comfy.model_management.sync_stream(device, offload_stream)
+
+ weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
+
+ if not prefetched:
+ if getattr(s, "_prefetch")["signature"] is not None:
+ offload_device = device
+ for param_key in ("weight", "bias"):
+ lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+ if lowvram_fn is not None:
+ lowvram_fn.clear_prepared()
+ delattr(s, "_prefetch")
+ return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
+
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@@ -280,11 +360,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function:
weight = f(weight)
- if offloadable:
- return weight, bias, (offload_stream, weight_a, bias_a)
- else:
- #Legacy function signature
- return weight, bias
+ return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
def uncast_bias_weight(s, weight, bias, offload_stream):
@@ -1173,6 +1249,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf)
return self
+ class Embedding(manual_cast.Embedding):
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys, error_msgs):
+ weight_key = f"{prefix}weight"
+ layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+ if layer_conf is not None:
+ layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+ # Only fp8 makes sense for embeddings (per-row dequant via index select).
+ # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
+ quant_format = layer_conf.get("format", None) if layer_conf is not None else None
+ if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
+ self.quant_format = quant_format
+ qconfig = QUANT_ALGOS[quant_format]
+ layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
+ weight = state_dict.pop(weight_key)
+ manually_loaded_keys = [weight_key]
+
+ scale_key = f"{prefix}weight_scale"
+ scale = state_dict.pop(scale_key, None)
+ if scale is not None:
+ scale = scale.float()
+ manually_loaded_keys.append(scale_key)
+
+ params = layout_cls.Params(
+ scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
+ orig_dtype=MixedPrecisionOps._compute_dtype,
+ orig_shape=(self.num_embeddings, self.embedding_dim),
+ )
+ self.weight = torch.nn.Parameter(
+ QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
+ requires_grad=False)
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ for k in manually_loaded_keys:
+ if k in missing_keys:
+ missing_keys.remove(k)
+ else:
+ if layer_conf is not None:
+ state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ def state_dict(self, *args, destination=None, prefix="", **kwargs):
+ if destination is not None:
+ sd = destination
+ else:
+ sd = {}
+
+ if not hasattr(self, 'weight') or self.weight is None:
+ return sd
+
+ if isinstance(self.weight, QuantizedTensor):
+ sd_out = self.weight.state_dict("{}weight".format(prefix))
+ for k in sd_out:
+ sd[k] = sd_out[k]
+
+ quant_conf = {"format": self.quant_format}
+ sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+ else:
+ sd["{}weight".format(prefix)] = self.weight
+ return sd
+
+ def forward_comfy_cast_weights(self, input, out_dtype=None):
+ weight = self.weight
+
+ # Optimized path: lookup in fp8, dequantize only the selected rows.
+ if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
+ qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
+ if isinstance(qdata, QuantizedTensor):
+ scale = qdata._params.scale
+ qdata = qdata._qdata
+ else:
+ scale = None
+
+ x = torch.nn.functional.embedding(
+ input, qdata, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+ uncast_bias_weight(self, qdata, None, offload_stream)
+ target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
+ x = x.to(dtype=target_dtype)
+ if scale is not None and scale != 1.0:
+ x = x * scale.to(dtype=target_dtype)
+ return x
+
+ # Fallback for non-quantized or weight_function (LoRA) case
+ return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
+
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 42ee08fb2..b90bcfd25 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -1,6 +1,8 @@
import torch
import logging
+from comfy.cli_args import args
+
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
@@ -21,7 +23,15 @@ try:
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
- ck.registry.disable("triton")
+ if args.enable_triton_backend:
+ try:
+ import triton
+ logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
+ except ImportError as e:
+ logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
+ ck.registry.disable("triton")
+ else:
+ ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
index ab7cf14fa..e54be98d6 100644
--- a/comfy/rmsnorm.py
+++ b/comfy/rmsnorm.py
@@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
+# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index bbba09e26..3782fd2d5 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
- control_nets = set(cnets)
+ # Order-preserving dedup: a plain set() would randomize iteration order across runs.
+ control_nets = list(dict.fromkeys(cnets))
inference_memory = 0
control_models = []
diff --git a/comfy/sd.py b/comfy/sd.py
index ee66490f5..749bdd710 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -65,6 +65,8 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
+import comfy.text_encoders.gemma4
+import comfy.text_encoders.cogvideo
import comfy.model_patcher
import comfy.lora
@@ -1223,6 +1225,7 @@ class CLIPType(Enum):
NEWBIE = 24
FLUX2 = 25
LONGCAT_IMAGE = 26
+ COGVIDEOX = 27
@@ -1271,6 +1274,9 @@ class TEModel(Enum):
QWEN35_9B = 26
QWEN35_27B = 27
MINISTRAL_3_3B = 28
+ GEMMA_4_E4B = 29
+ GEMMA_4_E2B = 30
+ GEMMA_4_31B = 31
def detect_te_model(sd):
@@ -1296,6 +1302,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+ if 'model.layers.59.self_attn.q_norm.weight' in sd:
+ return TEModel.GEMMA_4_31B
+ if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
+ return TEModel.GEMMA_4_E4B
+ if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
+ return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1418,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+ elif clip_type == CLIPType.COGVIDEOX:
+ clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -1435,6 +1450,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+ elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
+ variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
+ TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
+ TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
+ clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
+ clip_target.tokenizer = variant.tokenizer
+ tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 92d0305c5..6a9613602 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1167,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
+class WAN21_CausalAR_T2V(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "t2v",
+ "causal_ar": True,
+ }
+
+ sampling_settings = {
+ "shift": 5.0,
+ }
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+ self.unet_config.pop("causal_ar", None)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.WAN21_CausalAR(self, device=device)
+
+
class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1853,6 +1872,14 @@ class CogVideoX_T2V(supported_models_base.BASE):
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
+ def __init__(self, unet_config):
+ # 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426.
+ # 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and
+ # Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json.
+ if unet_config.get("num_attention_heads", 0) >= 48:
+ self.latent_format = latent_formats.CogVideoX1_5
+ super().__init__(unet_config)
+
def get_model(self, state_dict, prefix="", device=None):
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
if self.unet_config.get("patch_size_t") is not None:
@@ -1879,6 +1906,102 @@ class CogVideoX_I2V(CogVideoX_T2V):
out = model_base.CogVideoX(self, image_to_video=True, device=device)
return out
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31, CogVideoX_I2V, CogVideoX_T2V]
+class CogVideoX_Inpaint(CogVideoX_T2V):
+ unet_config = {
+ "image_model": "cogvideox",
+ "in_channels": 48,
+ }
-models += [SVD_img2vid]
+ def get_model(self, state_dict, prefix="", device=None):
+ if self.unet_config.get("patch_size_t") is not None:
+ self.unet_config.setdefault("sample_height", 96)
+ self.unet_config.setdefault("sample_width", 170)
+ self.unet_config.setdefault("sample_frames", 81)
+ out = model_base.CogVideoX(self, image_to_video=True, device=device)
+ return out
+
+
+models = [
+ LotusD,
+ Stable_Zero123,
+ SD15_instructpix2pix,
+ SD15,
+ SD20,
+ SD21UnclipL,
+ SD21UnclipH,
+ SDXL_instructpix2pix,
+ SDXLRefiner,
+ SDXL,
+ SSD1B,
+ KOALA_700M,
+ KOALA_1B,
+ Segmind_Vega,
+ SD_X4Upscaler,
+ Stable_Cascade_C,
+ Stable_Cascade_B,
+ SV3D_u,
+ SV3D_p,
+ SD3,
+ StableAudio,
+ AuraFlow,
+ PixArtAlpha,
+ PixArtSigma,
+ HunyuanDiT,
+ HunyuanDiT1,
+ FluxInpaint,
+ Flux,
+ LongCatImage,
+ FluxSchnell,
+ GenmoMochi,
+ LTXV,
+ LTXAV,
+ HunyuanVideo15_SR_Distilled,
+ HunyuanVideo15,
+ HunyuanImage21Refiner,
+ HunyuanImage21,
+ HunyuanVideoSkyreelsI2V,
+ HunyuanVideoI2V,
+ HunyuanVideo,
+ CosmosT2V,
+ CosmosI2V,
+ CosmosT2IPredict2,
+ CosmosI2VPredict2,
+ ZImagePixelSpace,
+ ZImage,
+ Lumina2,
+ WAN22_T2V,
+ WAN21_CausalAR_T2V,
+ WAN21_T2V,
+ WAN21_I2V,
+ WAN21_FunControl2V,
+ WAN21_Vace,
+ WAN21_Camera,
+ WAN22_Camera,
+ WAN22_S2V,
+ WAN21_HuMo,
+ WAN22_Animate,
+ WAN21_FlowRVS,
+ WAN21_SCAIL,
+ Hunyuan3Dv2mini,
+ Hunyuan3Dv2,
+ Hunyuan3Dv2_1,
+ HiDream,
+ Chroma,
+ ChromaRadiance,
+ ACEStep,
+ ACEStep15,
+ Omnigen2,
+ QwenImage,
+ Flux2,
+ Kandinsky5Image,
+ Kandinsky5,
+ Anima,
+ RT_DETR_v4,
+ ErnieImage,
+ SAM3,
+ SAM31,
+ CogVideoX_Inpaint,
+ CogVideoX_I2V,
+ CogVideoX_T2V,
+ SVD_img2vid,
+]
diff --git a/comfy/text_encoders/cogvideo.py b/comfy/text_encoders/cogvideo.py
index f1e8e3f5d..b97310709 100644
--- a/comfy/text_encoders/cogvideo.py
+++ b/comfy/text_encoders/cogvideo.py
@@ -1,6 +1,48 @@
import comfy.text_encoders.sd3_clip
+from comfy import sd1_clip
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
+ """Inner T5 tokenizer for CogVideoX.
+
+ CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3).
+ Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with
+ the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below.
+ """
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
+
+
+class CogVideoXTokenizer(sd1_clip.SD1Tokenizer):
+ """Outer tokenizer wrapper for CLIPLoader (type="cogvideox")."""
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
+ clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer)
+
+
+class CogVideoXT5XXL(sd1_clip.SD1ClipModel):
+ """Outer T5XXL model wrapper for CLIPLoader (type="cogvideox").
+
+ Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__
+ (which reads self.dtypes) works correctly. The inner model is the standard
+ sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX).
+ """
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super().__init__(device=device, dtype=dtype, name="t5xxl",
+ clip_model=comfy.text_encoders.sd3_clip.T5XXLModel,
+ model_options=model_options)
+
+
+def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None):
+ """Factory that returns a CogVideoXT5XXL class configured with the detected
+ T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts.
+ """
+ class CogVideoXTEModel_(CogVideoXT5XXL):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if t5_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
+ if dtype_t5 is not None:
+ dtype = dtype_t5
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return CogVideoXTEModel_
diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py
new file mode 100644
index 000000000..f050061ed
--- /dev/null
+++ b/comfy/text_encoders/gemma4.py
@@ -0,0 +1,1298 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from dataclasses import dataclass
+import math
+
+from comfy import sd1_clip
+import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.rmsnorm import rms_norm
+from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding
+
+
+# Intentional minor divergences from the transformers reference implementation:
+# - Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of a dtype-matched buffer tensor.
+# - RMSNorm uses torch's fused F.rms_norm: very slight numerical differences, but considerably faster.
+# - Input image and audio resizing/resampling is slightly different numerically.
+
+
+GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
+GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
+GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
+
+@dataclass
+class Gemma4Config:
+ vocab_size: int = 262144
+ hidden_size: int = 2560
+ intermediate_size: int = 10240
+ num_hidden_layers: int = 42
+ num_attention_heads: int = 8
+ num_key_value_heads: int = 2
+ max_position_embeddings: int = 131072
+ rms_norm_eps: float = 1e-6
+ rope_theta = [1000000.0, 10000.0]
+ transformer_type: str = "gemma4"
+ head_dim = 256
+ global_head_dim = 512
+ rms_norm_add = False
+ mlp_activation = "gelu_pytorch_tanh"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ sliding_attention = [512, 512, 512, 512, 512, False]
+ rope_scale = None
+ partial_rotary_factor: float = 0.25
+ final_norm: bool = True
+ lm_head: bool = False
+ final_logit_softcapping: float = 30.0
+ hidden_size_per_layer_input: int = 256
+ num_kv_shared_layers: int = 18
+ use_double_wide_mlp: bool = False
+ stop_tokens = [1, 50, 106]
+ vision_config = GEMMA4_VISION_CONFIG
+ audio_config = GEMMA4_AUDIO_CONFIG
+ mm_tokens_per_image = 280
+
+@dataclass
+class Gemma4_E2B_Config(Gemma4Config):
+ hidden_size: int = 1536
+ intermediate_size: int = 6144
+ num_hidden_layers: int = 35
+ num_key_value_heads: int = 1
+ sliding_attention = [512, 512, 512, 512, False]
+ num_kv_shared_layers: int = 20
+ use_double_wide_mlp: bool = True
+
+@dataclass
+class Gemma4_31B_Config(Gemma4Config):
+ hidden_size: int = 5376
+ intermediate_size: int = 21504
+ num_hidden_layers: int = 60
+ num_attention_heads: int = 32
+ num_key_value_heads: int = 16
+ sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+ hidden_size_per_layer_input: int = 0
+ num_kv_shared_layers: int = 0
+ audio_config = None
+ vision_config = GEMMA4_VISION_31B_CONFIG
+
+
+# Unfused RoPE: an addcmul_-based implementation diverges numerically from the reference code.
+def _apply_rotary_pos_emb(x, freqs_cis):
+ cos, sin = freqs_cis[0], freqs_cis[1]
+ half = x.shape[-1] // 2
+ out = x * cos
+ out[..., :half] -= x[..., half:] * sin[..., :half]
+ out[..., half:] += x[..., :half] * sin[..., half:]
+ return out
+
+class Gemma4Attention(nn.Module):
+ def __init__(self, config, head_dim, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.hidden_size = config.hidden_size
+ self.head_dim = head_dim
+ self.inner_size = self.num_heads * head_dim
+
+ self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
+ self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+ self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+ self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+
+ self.q_norm = None
+ self.k_norm = None
+ if config.q_norm == "gemma3":
+ self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ if config.k_norm == "gemma3":
+ self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask=None,
+ freqs_cis=None,
+ past_key_value=None,
+ sliding_window=None,
+ shared_kv=None,
+ ):
+ batch_size, seq_length, _ = hidden_states.shape
+
+ xq = self.q_proj(hidden_states)
+ xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ if self.q_norm is not None:
+ xq = self.q_norm(xq)
+
+ if shared_kv is not None:
+ xk, xv = shared_kv
+ # Apply RoPE to Q only (K already has RoPE from source layer)
+ xq = _apply_rotary_pos_emb(xq, freqs_cis)
+ present_key_value = None
+ shareable_kv = None
+ else:
+ xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
+ xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
+ if self.k_norm is not None:
+ xk = self.k_norm(xk)
+ xv = rms_norm(xv)
+ xk = xk.transpose(1, 2)
+ xv = xv.transpose(1, 2)
+ xq = _apply_rotary_pos_emb(xq, freqs_cis)
+ xk = _apply_rotary_pos_emb(xk, freqs_cis)
+
+ present_key_value = None
+ if past_key_value is not None:
+ cumulative_len = 0
+ if len(past_key_value) > 0:
+ past_key, past_value, cumulative_len = past_key_value
+ xk = torch.cat((past_key, xk), dim=2)
+ xv = torch.cat((past_value, xv), dim=2)
+ new_cumulative = cumulative_len + seq_length
+ if sliding_window is not None and xk.shape[2] > sliding_window - 1:
+ cache_k = xk[:, :, -(sliding_window - 1):]
+ cache_v = xv[:, :, -(sliding_window - 1):]
+ else:
+ cache_k = xk
+ cache_v = xv
+ present_key_value = (cache_k, cache_v, new_cumulative)
+
+ # KV for sharing: full xk/xv that SDPA sees (not evicted cache)
+ shareable_kv = (xk, xv)
+
+ # GQA: pass unexpanded KV with enable_gqa when there is no sliding mask;
+ # expand KV heads when a sliding mask is present.
+ # This has to be done within SDPA itself to match the reference code; pre-scaling expansion causes numerical differences.
+ expand_kv = (self.num_heads != self.num_kv_heads and
+ sliding_window is not None and
+ xk.shape[2] >= sliding_window)
+ if expand_kv:
+ xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+ xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+ gqa_kwargs = {} if expand_kv else ({"enable_gqa": True} if self.num_heads != self.num_kv_heads else {})
+ output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0, **gqa_kwargs)
+
+ return self.o_proj(output), present_key_value, shareable_kv
+
+
+class TransformerBlockGemma4(nn.Module):
+ def __init__(self, config, index, device=None, dtype=None, ops=None):
+ super().__init__()
+ if config.sliding_attention is not None:
+ self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+ else:
+ self.sliding_attention = False
+
+ head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
+
+ self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
+
+ num_kv_shared = config.num_kv_shared_layers
+ first_kv_shared = config.num_hidden_layers - num_kv_shared
+ mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None
+ self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size)
+
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ if self.hidden_size_per_layer_input:
+ self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
+ self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
+ self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
+ else:
+ self.layer_scalar = None
+
+ def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None):
+ sliding_window = None
+ if self.sliding_attention:
+ sliding_window = self.sliding_attention
+ # For prefill > sliding window, add sliding window restriction to the causal mask.
+ if x.shape[1] > self.sliding_attention:
+ sw_mask = torch.zeros(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
+ sw_mask.masked_fill_(torch.ones_like(sw_mask, dtype=torch.bool).tril_(-self.sliding_attention), torch.finfo(x.dtype).min)
+ attention_mask = attention_mask + sw_mask if attention_mask is not None else sw_mask
+ freqs_cis = freqs_cis[1]
+ else:
+ freqs_cis = freqs_cis[0]
+
+ residual = x
+ x = self.input_layernorm(x)
+ x, present_key_value, shareable_kv = self.self_attn(
+ hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis,
+ past_key_value=past_key_value, sliding_window=sliding_window, shared_kv=shared_kv,
+ )
+ x = self.post_attention_layernorm(x)
+ x = residual + x
+
+ residual = x
+ x = self.pre_feedforward_layernorm(x)
+ x = self.mlp(x)
+ x = self.post_feedforward_layernorm(x)
+ x = residual + x
+
+ if self.hidden_size_per_layer_input and per_layer_input is not None:
+ residual = x
+ x = self.per_layer_input_gate(x)
+ x = torch.nn.functional.gelu(x, approximate="tanh")
+ x = x * per_layer_input
+ x = self.per_layer_projection(x)
+ x = self.post_per_layer_input_norm(x)
+ x = residual + x
+
+ if self.layer_scalar is not None:
+ x = x * self.layer_scalar
+
+ return x, present_key_value, shareable_kv
+
+
+class Gemma4Transformer(nn.Module):
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.config = config
+
+ self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
+
+ self.layers = nn.ModuleList([
+ TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops)
+ for i in range(config.num_hidden_layers)
+ ])
+
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None
+
+ # Precompute RoPE inv_freq on CPU to match reference code's exact value
+ rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2)
+ nope_global = config.global_head_dim // 2 - rope_angles_global
+ global_inv = 1.0 / (config.rope_theta[0] ** (torch.arange(0, 2 * rope_angles_global, 2).float() / config.global_head_dim))
+ if nope_global > 0:
+ global_inv = torch.cat([global_inv, torch.zeros(nope_global)])
+ self.register_buffer("_global_inv_freq", global_inv, persistent=False)
+
+ sliding_inv = 1.0 / (config.rope_theta[1] ** (torch.arange(0, config.head_dim, 2).float() / config.head_dim))
+ self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False)
+
+ # Per-layer input mechanism
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ if self.hidden_size_per_layer_input:
+ self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype)
+ self.per_layer_model_projection = ops.Linear(
+ config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input,
+ bias=False, device=device, dtype=dtype)
+ self.per_layer_projection_norm = RMSNorm(
+ self.hidden_size_per_layer_input, eps=config.rms_norm_eps,
+ device=device, dtype=dtype)
+
+ def get_past_len(self, past_key_values):
+ for kv in past_key_values:
+ if len(kv) >= 3:
+ return kv[2]
+ return 0
+
+ def _freqs_from_inv(self, inv_freq, position_ids, device, dtype):
+ """Compute cos/sin from stored inv_freq"""
+ inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device)
+ pos_exp = position_ids[:, None, :].float()
+ freqs = (inv_exp @ pos_exp).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype)
+
+ def compute_freqs_cis(self, position_ids, device, dtype=None):
+ global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype)
+ sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype)
+ return [global_freqs, sliding_freqs]
+
+ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None,
+ final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=None,
+ past_key_values=None, input_ids=None):
+ if embeds is not None:
+ x = embeds
+ else:
+ x = self.embed_tokens(x, out_dtype=dtype)
+
+ seq_len = x.shape[1]
+ past_len = 0
+ if past_key_values is not None and len(past_key_values) > 0:
+ past_len = self.get_past_len(past_key_values)
+
+ if position_ids is None:
+ position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
+
+ freqs_cis = self.compute_freqs_cis(position_ids, x.device, dtype=x.dtype)
+
+ mask = None
+ min_val = torch.finfo(x.dtype).min
+ if attention_mask is not None:
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
+ mask = mask.masked_fill(mask.to(torch.bool), min_val)
+
+ if seq_len > 1:
+ causal_mask = torch.zeros(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device)
+ causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val)
+ mask = mask + causal_mask if mask is not None else causal_mask
+
+ # Per-layer inputs
+ per_layer_inputs = None
+ if self.hidden_size_per_layer_input:
+ num_layers = self.config.num_hidden_layers
+ hpl = self.hidden_size_per_layer_input
+ per_layer_proj = self.per_layer_model_projection(x) * (1.0 / (self.config.hidden_size ** 0.5))
+ per_layer_proj = self.per_layer_projection_norm(per_layer_proj.reshape(*x.shape[:-1], num_layers, hpl))
+ if input_ids is not None and input_ids.shape[1] == x.shape[1]:
+ per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl)
+ per_layer_inputs = (per_layer_proj + per_layer_emb) * (0.5 ** 0.5)
+ else:
+ per_layer_inputs = per_layer_proj
+
+ # KV sharing: later layers reuse KV from the last non-shared sliding/global layer
+ num_kv_shared = self.config.num_kv_shared_layers
+ first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers
+ shared_sliding_kv = None # KV from last non-shared sliding layer
+ shared_global_kv = None # KV from last non-shared global layer
+
+ intermediate = None
+ next_key_values = []
+ for i, layer in enumerate(self.layers):
+ past_kv = past_key_values[i] if past_key_values is not None and len(past_key_values) > 0 else None
+
+ layer_kwargs = {}
+ if per_layer_inputs is not None:
+ layer_kwargs['per_layer_input'] = per_layer_inputs[:, :, i, :]
+
+ is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention
+ if i >= first_kv_shared and num_kv_shared > 0:
+ shared = shared_sliding_kv if is_sliding else shared_global_kv
+ if shared is not None:
+ layer_kwargs['shared_kv'] = shared
+
+ x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis, past_key_value=past_kv, **layer_kwargs)
+
+ next_key_values.append(current_kv if current_kv is not None else ())
+
+ # Only track the last sliding/global before the sharing boundary
+ if i < first_kv_shared and shareable_kv is not None:
+ if is_sliding:
+ shared_sliding_kv = shareable_kv
+ else:
+ shared_global_kv = shareable_kv
+
+ if i == intermediate_output:
+ intermediate = x.clone()
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ if len(next_key_values) > 0:
+ return x, intermediate, next_key_values
+ return x, intermediate
+
+
+class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module):
+ """Common base for all Gemma4 variants: text model + vision."""
+ def _init_model(self, config, dtype, device, operations):
+ self.num_layers = config.num_hidden_layers
+ self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+ self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations)
+ self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations)
+
+ def logits(self, x):
+ logits = super().logits(x)
+ cap = self.model.config.final_logit_softcapping
+ if cap:
+ logits = cap * torch.tanh(logits / cap)
+ return logits
+
+ def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
+ past_key_values = []
+ for _ in range(self.model.config.num_hidden_layers):
+ past_key_values.append(())
+ return past_key_values
+
+ def preprocess_embed(self, embed, device):
+ if embed["type"] == "image":
+ image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W]
+ max_soft_tokens = embed.get("max_soft_tokens", None)
+ vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens)
+ return self.multi_modal_projector(vision_out), None
+ return None, None
+
+
+class Gemma4AudioMixin:
+ """Adds audio support to a Gemma4 model."""
+ def _init_audio(self, config, dtype, device, operations):
+ self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations)
+ self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations)
+
+ def preprocess_embed(self, embed, device):
+ result, extra = super().preprocess_embed(embed, device)
+ if result is not None:
+ return result, extra
+ if embed["type"] == "audio":
+ audio = embed.pop("data").to(device, dtype=torch.float32)
+ audio_mask = embed.pop("mask", None)
+ if audio_mask is not None:
+ audio_mask = audio_mask.to(device)
+ audio_out = self.audio_model(audio, audio_mask=audio_mask)
+ return self.audio_projector(audio_out), None
+ return None, None
+
+
+# Vision Encoder
+
+def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
+ """Compute 2D RoPE for vision: separate frequencies for x and y dimensions.
+
+ Args:
+ head_dim: dimension per head (e.g. 64)
+ pixel_position_ids: [batch, num_patches, 2] with (x, y) coords
+ theta: RoPE base frequency
+ Returns:
+ (cos, sin) each of shape [batch, num_patches, head_dim]
+ """
+ rotary_dim_per_axis = head_dim // 2
+ freq_indices = torch.arange(0, rotary_dim_per_axis, 2, device=device).float()
+ inv_freq = 1.0 / (theta ** (freq_indices / rotary_dim_per_axis))
+
+ all_cos, all_sin = [], []
+ for i in range(2): # x and y
+ dim_positions = pixel_position_ids[:, :, i].float() # [batch, num_patches]
+ freqs = torch.einsum('bi,j->bij', dim_positions, inv_freq.to(device)) # [batch, num_patches, rotary_dim/2]
+ emb = torch.cat([freqs, freqs], dim=-1) # [batch, num_patches, rotary_dim]
+ all_cos.append(emb.cos())
+ all_sin.append(emb.sin())
+
+ cos = torch.cat(all_cos, dim=-1).to(pixel_position_ids.device) # [batch, num_patches, head_dim]
+ sin = torch.cat(all_sin, dim=-1).to(pixel_position_ids.device)
+ return cos, sin
+
+
def _apply_vision_2d_rope(x, freqs):
    """Apply 2-axis RoPE to vision query/key tensors.

    The head dim is split into two halves (x-axis, y-axis) and each half gets
    ordinary 1D rotary applied with its matching cos/sin slice.

    x: [batch, heads, seq, head_dim]
    freqs: (cos, sin), each [batch, seq, head_dim]
    """
    cos = freqs[0].unsqueeze(1)  # broadcast over the heads axis
    sin = freqs[1].unsqueeze(1)
    split = x.shape[-1] // 2
    lo = _apply_rotary_pos_emb(x[..., :split], (cos[..., :split], sin[..., :split]))
    hi = _apply_rotary_pos_emb(x[..., split:], (cos[..., split:], sin[..., split:]))
    return torch.cat([lo, hi], dim=-1)
+
+
class ClippedLinear(nn.Module):
    """Linear layer whose input and output are clamped to QAT-derived bounds.

    The four bounds (input_min/max, output_min/max) are registered as buffers
    so checkpoints can supply them; they default to +/-inf (no clipping).
    """
    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None):
        super().__init__()
        self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
        for name, bound in (("input_max", float("inf")), ("input_min", float("-inf")),
                            ("output_max", float("inf")), ("output_min", float("-inf"))):
            self.register_buffer(name, torch.tensor(bound, device=device, dtype=dtype))

    @property
    def weight(self):
        # Expose the wrapped layer's weight so callers can treat this module
        # like a plain Linear.
        return self.linear.weight

    def forward(self, x):
        clipped = x.clamp(min=self.input_min, max=self.input_max)
        out = self.linear(clipped)
        # In-place clamp is safe: `out` is freshly allocated by the linear.
        return out.clamp_(min=self.output_min, max=self.output_max)
+
+
class Gemma4VisionMLP(nn.Module):
    """Gated MLP with gate/up/down projections and tanh-approximate GELU gating."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        dim = config["hidden_size"]
        inner = config["intermediate_size"]
        proj_kwargs = dict(device=device, dtype=dtype, ops=ops)
        self.gate_proj = ClippedLinear(dim, inner, **proj_kwargs)
        self.up_proj = ClippedLinear(dim, inner, **proj_kwargs)
        self.down_proj = ClippedLinear(inner, dim, **proj_kwargs)

    def forward(self, x):
        gate = torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh")
        return self.down_proj(gate * self.up_proj(x))
+
+
class Gemma4VisionAttention(nn.Module):
    """Multi-head self-attention for the vision encoder with q/k RMSNorm and 2D RoPE."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads)

        inner = self.num_heads * self.head_dim
        proj_kwargs = dict(device=device, dtype=dtype, ops=ops)
        self.q_proj = ClippedLinear(self.hidden_size, inner, **proj_kwargs)
        self.k_proj = ClippedLinear(self.hidden_size, inner, **proj_kwargs)
        self.v_proj = ClippedLinear(self.hidden_size, inner, **proj_kwargs)
        self.o_proj = ClippedLinear(inner, self.hidden_size, **proj_kwargs)

        # Learned RMSNorm on q/k heads; v gets a parameterless rms_norm in forward().
        self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
        self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)

    def forward(self, x, freqs, attention_mask=None):
        bsz, seq, _ = x.shape
        head_shape = (bsz, seq, self.num_heads, self.head_dim)

        q = self.q_norm(self.q_proj(x).view(head_shape)).transpose(1, 2)
        k = self.k_norm(self.k_proj(x).view(head_shape)).transpose(1, 2)
        v = rms_norm(self.v_proj(x).view(head_shape))

        q = _apply_vision_2d_rope(q, freqs)
        k = _apply_vision_2d_rope(k, freqs)
        v = v.to(q.dtype).transpose(1, 2)

        attn_fn = optimized_attention_for_device(q.device, mask=attention_mask is not None, small_input=True)
        out = attn_fn(q, k, v, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0)
        return self.o_proj(out)
+
+
class Gemma4VisionLayer(nn.Module):
    """Transformer block with pre+post norms around both attention and MLP (Gemma style)."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
        width = config["hidden_size"]

        def make_norm():
            return RMSNorm(width, eps=config["rms_norm_eps"], device=device, dtype=dtype)

        self.input_layernorm = make_norm()
        self.post_attention_layernorm = make_norm()
        self.pre_feedforward_layernorm = make_norm()
        self.post_feedforward_layernorm = make_norm()

    def forward(self, x, freqs, attention_mask=None):
        # Attention sub-block: norm -> attn -> norm, then residual add.
        attn_out = self.self_attn(self.input_layernorm(x), freqs, attention_mask=attention_mask)
        x = x + self.post_attention_layernorm(attn_out)

        # Feed-forward sub-block: norm -> mlp -> norm, then residual add.
        mlp_out = self.mlp(self.pre_feedforward_layernorm(x))
        return x + self.post_feedforward_layernorm(mlp_out)
+
+
class Gemma4PatchEmbedder(nn.Module):
    """Projects flattened pixel patches and adds learned 2D position embeddings.

    Position embeddings live in a [2, table_size, hidden] parameter and are
    looked up separately for the x and y coordinate of every patch.
    """
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        hidden_size = config["hidden_size"]
        self.patch_size = config["patch_size"]
        self.position_embedding_size = config.get("position_embedding_size", 10240)

        self.input_proj = ops.Linear(3 * self.patch_size * self.patch_size, hidden_size, bias=False, device=device, dtype=dtype)
        self.position_embedding_table = nn.Parameter(
            torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype)
        )

    def forward(self, patches, pixel_position_ids):
        """
        patches: [B, num_patches, 3*patch_size^2] in [0,1]; rescaled to [-1,1] here (matching HF)
        pixel_position_ids: [B, num_patches, 2] (x, y) positions; (-1,-1) marks padding
        """
        rescaled = (2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype)
        hidden_states = self.input_proj(rescaled)

        # Clamp so (-1,-1) padding entries index row 0; they are zeroed below.
        safe_positions = pixel_position_ids.clamp(min=0)
        table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype)
        pos_emb = table[0][safe_positions[..., 0]] + table[1][safe_positions[..., 1]]

        # Zero out position embeddings for padding patches (matching HF).
        is_padding = (pixel_position_ids == -1).all(dim=-1)
        pos_emb = torch.where(is_padding.unsqueeze(-1), 0.0, pos_emb)

        return hidden_states + pos_emb
+
+
class Gemma4VisionEncoderLayers(nn.Module):
    """Thin container so checkpoint keys come out as encoder.layers.<i>.*"""
    def __init__(self, config, dtype=None, device=None, ops=None):
        super().__init__()
        blocks = [Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops)
                  for _ in range(config["num_hidden_layers"])]
        self.layers = nn.ModuleList(blocks)
+
+
class Gemma4VisionEncoder(nn.Module):
    """Vision tower: patchify -> embed -> transformer layers -> k*k average pool.

    forward() returns pooled patch features scaled by sqrt(hidden_size).
    """
    def __init__(self, config, dtype=None, device=None, ops=None):
        super().__init__()
        self.config = config
        self.hidden_size = config["hidden_size"]
        self.head_dim = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"])
        self.patch_size = config["patch_size"]
        self.pooling_kernel_size = config.get("pooling_kernel_size", 3)
        # Final features are multiplied by sqrt(hidden_size) in forward().
        self.root_hidden_size = self.hidden_size ** 0.5

        self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops)
        self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops)

    def forward(self, pixel_values, max_soft_tokens=None):
        """
        pixel_values: [B, C, H, W] in [0,1] range
        max_soft_tokens: if provided, pad to max_soft_tokens * k² total patches
        """
        batch_size, _, height, width = pixel_values.shape
        ps = self.patch_size
        k = self.pooling_kernel_size
        patches_h, patches_w = height // ps, width // ps
        num_patches = patches_h * patches_w
        # Each output token pools a k x k block of patches.
        output_length = max_soft_tokens if max_soft_tokens is not None else num_patches // (k * k)
        n_padding = output_length * k * k - num_patches

        # Patchify and build position grid
        patches = pixel_values.reshape(batch_size, -1, patches_h, ps, patches_w, ps)
        patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, num_patches, -1)
        grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij')
        position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1)

        # Append zero-pixel padding with (-1,-1) positions
        if n_padding > 0:
            patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1)
            position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1)

        # True where a patch is pure padding (both coordinates are -1).
        padding = (position_ids == -1).all(dim=-1)

        # Embed, encode, pool
        x = self.patch_embedder(patches, position_ids)
        freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
        freqs = tuple(t.to(x.dtype) for t in freqs)
        if n_padding > 0:
            # Additive attention mask: -inf at padded key positions, 0 elsewhere.
            mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1)
            mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min)
        else:
            mask = None

        for layer in self.encoder.layers:
            x = layer(x, freqs, attention_mask=mask)

        if n_padding > 0:
            # Zero padded rows so they contribute nothing to the pooled averages.
            x = x.masked_fill(padding.unsqueeze(-1), 0.0)

        # Average pool by spatial position
        clamped = position_ids.clamp(min=0)
        max_x = clamped[:, :, 0].max(dim=-1, keepdim=True)[0] + 1
        # Map each patch (x, y) to its pooled output index: (x//k) + (W//k)*(y//k).
        ki = torch.div(clamped, k, rounding_mode="floor")
        ki = ki[:, :, 0] + (max_x // k) * ki[:, :, 1]
        # One-hot pooling matrix; dividing by k*k turns the matmul into a mean.
        weights = torch.nn.functional.one_hot(ki.long(), output_length).float() / (k * k)
        x = (weights.transpose(1, 2) @ x.float()).to(x.dtype)

        # Strip empty output tokens
        valid_out = ~((weights == 0).all(dim=1))
        if valid_out.any() and not valid_out.all():
            # NOTE(review): for batch_size > 1 this applies batch item 0's
            # validity to the whole batch — assumes identical padding layout
            # across the batch; confirm with callers.
            x = x[:, valid_out[0]] if batch_size > 1 else x[valid_out].unsqueeze(0)

        return x * self.root_hidden_size
+
+
class Gemma4RMSNormProjector(nn.Module):
    """Parameterless RMSNorm followed by a bias-free linear projection.

    Shared base class for the vision and audio modality projectors.
    """
    def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None):
        super().__init__()
        self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        normed = rms_norm(x)
        return self.embedding_projection(normed)
+
+
class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
    """Maps vision-encoder features into the text model's hidden space."""
    def __init__(self, config, dtype=None, device=None, ops=None):
        vision_dim = config.vision_config["hidden_size"]
        super().__init__(vision_dim, config.hidden_size, dtype=dtype, device=device, ops=ops)
+
+
+# Audio Encoder
+
class Gemma4AudioConvSubsampler(nn.Module):
    """Two stride-2 Conv2d stages that 4x-subsample audio features in time and
    frequency, followed by a linear projection to the encoder width."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        eps = config["rms_norm_eps"]

        def stage(in_ch, out_ch):
            return nn.ModuleDict({
                'conv': ops.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
                'norm': ops.LayerNorm(out_ch, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
            })

        self.layer0 = stage(1, 128)
        self.layer1 = stage(128, 32)
        # proj input: 32 channels * (128 mel bins / 4 after two stride-2 convs) = 1024
        self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype)

    def _conv_layer(self, x, layer, mask):
        # Zero out padded frames before convolving so padding cannot leak in.
        if mask is not None:
            x = x * mask[:, None, :, None].to(x.device)
        x = layer['conv'](x.to(layer['conv'].weight.dtype))
        # LayerNorm over channels (channels-last), then ReLU.
        x = torch.relu(layer['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous())
        if mask is not None:
            mask = mask[:, ::2]  # the stride-2 conv halves the time axis
        return x, mask

    def forward(self, x, mask=None):
        x = x.unsqueeze(1)  # [B, T, F] -> [B, 1, T, F]
        x, mask = self._conv_layer(x, self.layer0, mask)
        x, mask = self._conv_layer(x, self.layer1, mask)
        bsz, _, frames, _ = x.shape
        x = x.permute(0, 2, 3, 1).contiguous().reshape(bsz, frames, -1)
        return self.input_proj_linear(x), mask
+
+
class Gemma4AudioFeedForward(nn.Module):
    """Conformer-style feed-forward module whose output is scaled (default 0.5)
    before the residual connection."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        width = config["hidden_size"]
        inner = config.get("intermediate_size", width * 4)
        eps = config["rms_norm_eps"]
        self.pre_layer_norm = RMSNorm(width, eps=eps, device=device, dtype=dtype)
        self.ffw_layer_1 = ClippedLinear(width, inner, device=device, dtype=dtype, ops=ops)
        self.ffw_layer_2 = ClippedLinear(inner, width, device=device, dtype=dtype, ops=ops)
        self.post_layer_norm = RMSNorm(width, eps=eps, device=device, dtype=dtype)
        self.post_layer_scale = config.get("residual_weight", 0.5)

    def forward(self, x):
        h = self.pre_layer_norm(x)
        h = torch.nn.functional.silu(self.ffw_layer_1(h))
        h = self.post_layer_norm(self.ffw_layer_2(h))
        return h * self.post_layer_scale + x
+
+
+class Gemma4AudioRelPositionalEncoding(nn.Module):
+ """Sinusoidal relative positional encoding for audio attention."""
+ def __init__(self, config, device=None, dtype=None):
+ super().__init__()
+ hidden_size = config["hidden_size"]
+ context_left = config.get("attention_context_left", 13)
+ context_right = config.get("attention_context_right", 0)
+ self.chunk_size = config.get("attention_chunk_size", 12)
+ self.context_size = self.chunk_size + context_left - 1 + context_right
+
+ num_timescales = hidden_size // 2
+ log_inc = math.log(10000.0) / max(num_timescales - 1, 1)
+ inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).to(dtype=dtype).unsqueeze(0).unsqueeze(0)
+ self.register_buffer("inv_timescales", inv_timescales, persistent=False)
+
+ def forward(self, hidden_states):
+ positions = torch.arange(self.chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1)
+ scaled = positions * self.inv_timescales.to(device=hidden_states.device)
+ return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).to(dtype=hidden_states.dtype)
+
+
class Gemma4AudioAttention(nn.Module):
    """Chunked block attention with relative position bias and softcap.

    The sequence is split into fixed-size query chunks; each chunk attends to
    a local window of context_size keys (past horizon + chunk + future
    horizon), with a Transformer-XL style relative position bias and a tanh
    soft cap on the attention logits.
    """
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_heads = config["num_attention_heads"]
        self.head_dim = self.hidden_size // self.num_heads
        self.chunk_size = config.get("attention_chunk_size", 12)
        self.max_past_horizon = config.get("attention_context_left", 13) - 1
        self.max_future_horizon = config.get("attention_context_right", 0)
        self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon

        # The log(2) factors presumably fold the reference implementation's
        # base-2 exponent arithmetic into a plain softmax — TODO confirm.
        self.q_scale = (self.head_dim ** -0.5) / math.log(2)
        self.k_scale = math.log(1 + math.e) / math.log(2)
        self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False)

        self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
        self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
        self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
        self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
        # Learned per-channel query scale, passed through softplus in forward().
        self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype))
        self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype)

    def _convert_to_block(self, x):
        # Pad [B, S, H, D] to a whole number of chunks and reshape to
        # [B, num_blocks, chunk_size, H, D].
        B, S, H, D = x.shape
        num_blocks = (S + self.chunk_size - 1) // self.chunk_size
        pad = num_blocks * self.chunk_size - S
        x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad))
        return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous()

    def _extract_block_context(self, x):
        # Gather each chunk's key/value context window:
        # [B, S, H, D] -> [B, num_blocks, context_size, H, D].
        x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1))
        x = x.unfold(1, self.context_size, self.chunk_size)
        return torch.movedim(x, -1, 2).contiguous()

    def _rel_shift(self, x):
        # Transformer-XL style relative shift: pad one column, flatten, trim,
        # so column j lines up with the j-th context position per query.
        B, H, NB, BS, PL = x.shape
        CS = self.context_size
        x = torch.nn.functional.pad(x, (0, CS + 1 - PL))
        x = x.view(B, H, NB, BS * (CS + 1))
        x = x[..., :BS * CS]
        return x.view(B, H, NB, BS, CS)

    def _build_blocked_mask(self, seq_len, num_blocks, device, audio_mask=None):
        """Build 5D boolean blocked attention mask (True=attend, False=mask)"""
        q = torch.arange(seq_len, device=device)
        dist = q[:, None] - q[None, :]
        # NOTE(review): strict '<' excludes dist == max_past_horizon; verify
        # against the reference implementation whether '<=' was intended.
        mask = (dist >= 0) & (dist < self.max_past_horizon)
        if self.max_future_horizon > 0:
            mask = mask | ((dist < 0) & ((-dist) < self.max_future_horizon))
        if audio_mask is not None:
            # NOTE(review): only batch item 0's audio mask is applied — assumes
            # one shared validity mask per batch; confirm for batch > 1.
            mask = mask & audio_mask[0, None, :].bool()
        m = mask[None, None]
        # Reshape to blocked 5D matching reference code
        p = num_blocks * self.chunk_size - seq_len
        m = torch.nn.functional.pad(m, (0, p, 0, p), value=False)
        m = m.reshape(1, 1, num_blocks, self.chunk_size, -1)
        m = torch.nn.functional.pad(m, (self.max_past_horizon, self.max_future_horizon), value=False)
        idx = (torch.arange(num_blocks, device=device) * self.chunk_size)[:, None] + torch.arange(self.context_size, device=device)[None, :]
        return m.gather(-1, idx[None, None, :, None, :].expand(1, 1, -1, self.chunk_size, -1))

    def forward(self, x, position_embeddings=None, attn_mask=None):
        """x: [B, S, hidden]. position_embeddings come from
        Gemma4AudioRelPositionalEncoding. Returns [B, S, hidden]."""
        B, S, _ = x.shape

        # Attention math is carried out in float32.
        q = self.q_proj(x).float().view(B, S, self.num_heads, self.head_dim)
        k = self.k_proj(x).float().view(B, S, self.num_heads, self.head_dim)
        v = self.v_proj(x).float().view(B, S, self.num_heads, self.head_dim)

        q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale)
        k = k * self.k_scale

        q_blocks = self._convert_to_block(q)
        k_context = self._extract_block_context(k)
        v_context = self._extract_block_context(v)
        num_blocks = q_blocks.shape[1]

        rel_k = self.relative_k_proj(position_embeddings).view(-1, self.num_heads, self.head_dim).to(q.dtype)

        queries = q_blocks.permute(0, 3, 1, 2, 4)  # [B, H, NB, CS, D]
        # Content term: q against the chunk's key context.
        matrix_ac = queries @ k_context.permute(0, 3, 1, 4, 2)

        # Position term: q against projected relative-position keys, then
        # realigned with the relative shift.
        queries_flat = queries.reshape(B, self.num_heads, -1, self.head_dim)
        matrix_bd = queries_flat @ rel_k.permute(1, 2, 0)
        matrix_bd = matrix_bd.reshape(B, self.num_heads, num_blocks, self.chunk_size, -1)
        matrix_bd = self._rel_shift(matrix_bd)

        attn_weights = matrix_ac + matrix_bd
        # tanh soft cap keeps logits within (-softcap, softcap).
        attn_weights = torch.tanh(attn_weights / self.softcap) * self.softcap

        # Mask out invalid positions in chunk context (matching reference's masked_fill approach)
        if attn_mask is None:
            attn_mask = self._build_blocked_mask(S, num_blocks, x.device)
        attn_weights = attn_weights.masked_fill(attn_mask.logical_not(), -1e9)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
        out = attn_weights @ v_context.permute(0, 3, 1, 2, 4)
        out = out.permute(0, 2, 3, 1, 4).reshape(B, num_blocks * self.chunk_size, -1)
        out = out[:, :S].contiguous()  # drop the block padding at the tail
        return self.post(out.to(self.post.linear.weight.dtype))
+
+
class Gemma4AudioLConv1d(nn.Module):
    """Conformer lightweight-convolution module: GLU gate, causal depthwise
    conv, RMSNorm, SiLU, with a residual connection."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        width = config["hidden_size"]
        kernel = config.get("conv_kernel_size", 5)
        eps = config["rms_norm_eps"]
        self.pre_layer_norm = RMSNorm(width, eps=eps, device=device, dtype=dtype)
        self.linear_start = ClippedLinear(width, width * 2, device=device, dtype=dtype, ops=ops)
        # Depthwise conv with no built-in padding; causality comes from the
        # explicit left-pad applied in forward().
        self.depthwise_conv1d = ops.Conv1d(width, width, kernel_size=kernel, padding=0, groups=width, bias=False, device=device, dtype=dtype)
        self.conv_left_pad = kernel - 1  # causal: pad left by kernel-1
        self.conv_norm = RMSNorm(width, eps=eps, device=device, dtype=dtype)
        self.linear_end = ClippedLinear(width, width, device=device, dtype=dtype, ops=ops)

    def forward(self, x):
        skip = x
        h = self.linear_start(self.pre_layer_norm(x))
        h = torch.nn.functional.glu(h, dim=-1)
        # [B, T, C] -> [B, C, T]; left-pad so the conv never sees the future.
        h = torch.nn.functional.pad(h.transpose(1, 2), (self.conv_left_pad, 0))
        h = self.depthwise_conv1d(h).transpose(1, 2)
        h = torch.nn.functional.silu(self.conv_norm(h))
        return self.linear_end(h) + skip
+
+
class Gemma4AudioLayer(nn.Module):
    """Conformer block: FFN1 -> self-attention -> lightweight conv -> FFN2,
    followed by a final output norm."""
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        width = config["hidden_size"]
        eps = config["rms_norm_eps"]
        self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
        self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
        self.norm_pre_attn = RMSNorm(width, eps=eps, device=device, dtype=dtype)
        self.norm_post_attn = RMSNorm(width, eps=eps, device=device, dtype=dtype)
        self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops)
        self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
        self.norm_out = RMSNorm(width, eps=eps, device=device, dtype=dtype)

    def forward(self, x, position_embeddings=None, attn_mask=None):
        x = self.feed_forward1(x)

        # Attention sub-block wrapped in pre/post norms with a residual add.
        attn_in = self.norm_pre_attn(x)
        attn_out = self.self_attn(attn_in, position_embeddings=position_embeddings, attn_mask=attn_mask)
        x = self.norm_post_attn(attn_out) + x

        x = self.feed_forward2(self.lconv1d(x))
        return self.norm_out(x)
+
+
class Gemma4AudioEncoder(nn.Module):
    """Audio tower: conv subsampler, stack of conformer layers, output projection."""
    def __init__(self, config, dtype=None, device=None, ops=None):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.output_proj_dims = config.get("output_proj_dims", 1536)

        self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops)
        self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype)

        self.layers = nn.ModuleList(
            Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops)
            for _ in range(config["num_hidden_layers"])
        )

        self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype)

    def forward(self, audio_features, audio_mask=None):
        x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask)
        position_embeddings = self.rel_pos_enc(x)

        # The blocked attention mask depends only on the (shared) chunking
        # config and sequence length, so build it once and reuse per layer.
        attn = self.layers[0].self_attn
        seq_len = x.shape[1]
        num_blocks = (seq_len + attn.chunk_size - 1) // attn.chunk_size
        attn_mask = attn._build_blocked_mask(seq_len, num_blocks, x.device, audio_mask=audio_mask)

        for layer in self.layers:
            x = layer(x, position_embeddings=position_embeddings, attn_mask=attn_mask)

        return self.output_proj(x)
+
+
class Gemma4AudioProjector(Gemma4RMSNormProjector):
    """Maps audio-encoder features into the text model's hidden space."""
    def __init__(self, config, dtype=None, device=None, ops=None):
        in_dim = config.get("audio_output_proj_dims", 1536)
        out_dim = config.get("text_hidden_size", 2560)
        super().__init__(in_dim, out_dim, dtype=dtype, device=device, ops=ops)
+
+
+# Tokenizer and Wrappers
+
+class Gemma4_Tokenizer():
+ tokenizer_json_data = None
+
+ def state_dict(self):
+ if self.tokenizer_json_data is not None:
+ return {"tokenizer_json": self.tokenizer_json_data}
+ return {}
+
    def _extract_mel_spectrogram(self, waveform, sample_rate):
        """Extract 128-bin log mel spectrogram.

        Uses numpy for FFT/matmul/log to produce bit-identical results with
        reference code.

        Args:
            waveform: torch tensor, [channels, samples] or [samples]; mixed to
                mono and resampled to 16 kHz if needed.
            sample_rate: input sampling rate in Hz.
        Returns:
            (log_mel, mask): [T, 128] float32 log-mel features and a [T] bool
            mask that is False for frames consisting purely of padding.
        """
        # Mix to mono first, then resample to 16kHz
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        audio = waveform.squeeze(0).float().numpy()
        if sample_rate != 16000:
            # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match)
            from scipy.signal import resample_poly, firwin
            from math import gcd
            g = gcd(sample_rate, 16000)
            up, down = 16000 // g, sample_rate // g
            L = max(up, down)
            h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
            audio = resample_poly(audio, up, down, window=h).astype(np.float32)
        n = len(audio)

        # Pad to multiple of 128, build sample-level mask
        if n % 128 != 0:
            audio = np.pad(audio, (0, 128 - n % 128))
        mask_raw = np.ones(len(audio), dtype=np.float32)
        mask_raw[n:] = 0.0

        # Semicausal padding: 160 zeros prepended
        audio = np.pad(audio, (160, 0))
        mask_raw = np.pad(mask_raw, (160, 0))

        # Extract 321-sample frames via stride tricks, drop last → 320
        # (hop 160 samples; the .copy() detaches from the strided view).
        nf = (len(audio) - 321) // 160 + 1
        strides = (audio.strides[0] * 160, audio.strides[0])
        frames = np.lib.stride_tricks.as_strided(audio, (nf, 321), strides)[..., :-1].copy()

        # Periodic Hann window, FFT magnitude, mel filterbank, log
        window = (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(320) / 320)).astype(np.float32)
        magnitude = np.abs(np.fft.rfft(frames * window, n=512, axis=-1))
        mel_fb = self._build_mel_filterbank()
        # 0.001 floor avoids log(0); the explicit float64 keeps the final math
        # in double precision before the float32 cast, matching the reference.
        log_mel = np.log(np.matmul(magnitude, mel_fb) + np.float64(0.001)).astype(np.float32)

        # Frame mask: valid when last sample in window is real audio
        mask = mask_raw[np.arange(nf) * 160 + 320].astype(bool)
        log_mel = log_mel * mask[:, None]
        return torch.from_numpy(log_mel), torch.from_numpy(mask)  # [T, 128], [T]
+
+ @staticmethod
+ def _build_mel_filterbank():
+ """Build 128-bin HTK mel filterbank [257, 128] for 512-pt FFT at 16kHz."""
+ mel_freqs = np.linspace(0.0, 2595.0 * np.log10(1.0 + 8000.0 / 700.0), 130)
+ filter_freqs = 700.0 * (10.0 ** (mel_freqs / 2595.0) - 1.0)
+ fft_freqs = np.linspace(0, 16000 // 2, 257)
+ filter_diff = np.diff(filter_freqs)
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
+
+ def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, video=None, llama_template=None, skip_template=True, thinking=False, **kwargs):
+
+ # Process audio
+ audio_features = []
+ if audio is not None:
+ waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
+ sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
+ mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
+ audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T])
+
+ # Process image/video frames
+ is_video = video is not None
+ source = video if is_video else image
+ images = []
+ if source is not None:
+ samples = source.movedim(-1, 1) # [B, C, H, W]
+ num_frames = samples.shape[0]
+
+ # Subsample video to 1fps
+ if is_video:
+ fps = kwargs.get("fps", 24)
+ step = max(1, round(fps))
+ indices = list(range(0, num_frames, step))
+ if len(indices) == 0:
+ indices = [0]
+ samples = samples[indices]
+ num_frames = len(indices)
+
+ h, w = samples.shape[2], samples.shape[3]
+ patch_size = 16
+ pooling_k = 3
+ max_soft_tokens = 70 if is_video else 280 # video uses smaller token budget per frame
+ max_patches = max_soft_tokens * pooling_k * pooling_k
+ target_px = max_patches * patch_size * patch_size
+ factor = (target_px / (h * w)) ** 0.5
+ side_mult = pooling_k * patch_size
+ target_h = max(int(factor * h // side_mult) * side_mult, side_mult)
+ target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
+
+ import torchvision.transforms.functional as TVF
+ for i in range(num_frames):
+ # rescaling to match reference code
+ s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
+ if target_h != h or target_w != w:
+ s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True)
+ s = s.float() * (1.0 / 255.0)
+ images.append({"pixels": s.unsqueeze(0).movedim(1, -1)[:, :, :, :3], "max_soft_tokens": max_soft_tokens})
+
+ if text.startswith('<|turn>'):
+ skip_template = True
+
+ if skip_template:
+ llama_text = text
+ else:
+ if llama_template is not None:
+ llama_text = llama_template.format(text)
+ else:
+ # Build template from modalities present
+ system = "<|turn>system\n<|think|>\n" if thinking else ""
+ media = ""
+ if len(images) > 0:
+ if is_video:
+ media += "\n\n"
+ for i in range(len(images)):
+ ts = f"{int(i // 60):02d}:{int(i % 60):02d}"
+ sep = "" if i == 0 else " "
+ media += f"{sep}{ts} <|image><|video|>"
+ media += "\n\n"
+ else:
+ media += "\n\n"
+ for i in range(len(images)):
+ if i > 0:
+ media += "\n\n\n\n"
+ media += "<|image><|image|>"
+ media += "\n\n"
+ if len(audio_features) > 0:
+ # Compute audio token count (always at 16kHz)
+ num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1]
+ _fl = 320 # int(round(16000 * 20.0 / 1000.0))
+ _hl = 160 # int(round(16000 * 10.0 / 1000.0))
+ _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
+ _t = _nmel
+ for _ in range(2):
+ _t = (_t + 2 - 3) // 2 + 1
+ n_audio_tokens = min(_t, 750)
+ media += "<|audio>" + "<|audio|>" * n_audio_tokens + "