diff --git a/app/assets/services/metadata_extract.py b/app/assets/services/metadata_extract.py
index a004929bc..bdfe60218 100644
--- a/app/assets/services/metadata_extract.py
+++ b/app/assets/services/metadata_extract.py
@@ -4,7 +4,6 @@ Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""
-from __future__ import annotations
import json
import logging
diff --git a/app/custom_node_manager.py b/app/custom_node_manager.py
index 281febca9..738af2abd 100644
--- a/app/custom_node_manager.py
+++ b/app/custom_node_manager.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import os
import folder_paths
import glob
diff --git a/app/frontend_management.py b/app/frontend_management.py
index 483da2d29..8e84e8dd9 100644
--- a/app/frontend_management.py
+++ b/app/frontend_management.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
import argparse
import logging
import os
diff --git a/app/model_manager.py b/app/model_manager.py
index f124d1117..8f6e34b33 100644
--- a/app/model_manager.py
+++ b/app/model_manager.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import os
import base64
import json
diff --git a/app/user_manager.py b/app/user_manager.py
index 0517b3344..7b11e381c 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
import json
import os
import re
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 47b8174f4..9bda414d1 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
-parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.")
+parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 57126fa4a..bb21eb1d1 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -1,6 +1,5 @@
"""Comfy-specific type hinting"""
-from __future__ import annotations
from typing import Literal, TypedDict, Optional
from typing_extensions import NotRequired
from abc import ABC, abstractmethod
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index ba670b16d..6dbbaa959 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see .
"""
-
+from __future__ import annotations
import torch
from enum import Enum
import math
import os
import logging
+import copy
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from comfy.hooks import HookGroup
@@ -64,6 +65,18 @@ class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
+class ControlIsolation:
+ '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
+ def __init__(self, control: ControlBase):
+ self.control = control
+ self.orig_previous_controlnet = control.previous_controlnet
+
+ def __enter__(self):
+ self.control.previous_controlnet = None
+
+ def __exit__(self, *args):
+ self.control.previous_controlnet = self.orig_previous_controlnet
+
class ControlBase:
def __init__(self):
self.cond_hint_original = None
@@ -77,7 +90,7 @@ class ControlBase:
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
- self.previous_controlnet = None
+ self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
@@ -85,6 +98,7 @@ class ControlBase:
self.extra_concat = None
self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a
+ self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -111,17 +125,38 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
-
+ for device_cnet in self.multigpu_clones.values():
+ with ControlIsolation(device_cnet):
+ device_cnet.cleanup()
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
out = []
+ for device_cnet in self.multigpu_clones.values():
+ out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models()
return out
+ def get_models_only_self(self):
+ 'Calls get_models, but temporarily sets previous_controlnet to None.'
+ with ControlIsolation(self):
+ return self.get_models()
+
+ def get_instance_for_device(self, device):
+ 'Returns instance of this Control object intended for selected device.'
+ return self.multigpu_clones.get(device, self)
+
+ def deepclone_multigpu(self, load_device, autoregister=False):
+ '''
+ Create deep clone of Control object where model(s) is set to other devices.
+
+ When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
+ '''
+ raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.")
+
def get_extra_hooks(self):
out = []
if self.extra_hooks is not None:
@@ -130,7 +165,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks()
return out
- def copy_to(self, c):
+ def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
@@ -284,6 +319,14 @@ class ControlNet(ControlBase):
self.copy_to(c)
return c
+ def deepclone_multigpu(self, load_device, autoregister=False):
+ c = self.copy()
+ c.control_model = copy.deepcopy(c.control_model)
+ c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+ if autoregister:
+ self.multigpu_clones[load_device] = c
+ return c
+
def get_models(self):
out = super().get_models()
out.append(self.control_model_wrapped)
@@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet):
super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model)
+ def cleanup(self):
+ self.extra_args.pop("base_model", None)
+ super().cleanup()
+
def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
@@ -906,6 +953,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
+ def deepclone_multigpu(self, load_device, autoregister=False):
+ c = self.copy()
+ c.t2i_model = copy.deepcopy(c.t2i_model)
+ c.device = load_device
+ if autoregister:
+ self.multigpu_clones[load_device] = c
+ return c
+
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 75d459b59..12a934d71 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -799,13 +799,15 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
-
class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1.
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
"""
pass
+class PixelDiTPixel(ChromaRadiance):
+ pass
+
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
index bc36b8998..4e4819fe3 100644
--- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
+++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py
@@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module):
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
- if context.shape[0] >= 2:
- uncond_emb, cond_emb = context.chunk(2, dim = 0)
- context = torch.cat([cond_emb, uncond_emb], dim = 0)
+
+ swap_cfg_halves = context.shape[0] >= 2
+
+ if swap_cfg_halves:
+ first_half, second_half = context.chunk(2, dim = 0)
+ context = torch.cat([second_half, first_half], dim = 0)
+
main_condition = context
t = 1.0 - t
@@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module):
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
- if output.shape[0] >= 2:
- cond_emb, uncond_emb = output.chunk(2, dim = 0)
- return torch.cat([uncond_emb, cond_emb])
- else:
- return output
+ if swap_cfg_halves:
+ first_half, second_half = output.chunk(2, dim = 0)
+ output = torch.cat([second_half, first_half], dim = 0)
+
+ return output
diff --git a/comfy/ldm/lens/model.py b/comfy/ldm/lens/model.py
new file mode 100644
index 000000000..cd5015ddc
--- /dev/null
+++ b/comfy/ldm/lens/model.py
@@ -0,0 +1,510 @@
+"""Lens denoising transformer (DiT)"""
+
+from __future__ import annotations
+
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import comfy.ldm.flux.layers
+import comfy.patcher_extension
+from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
+from comfy.ldm.modules.attention import optimized_attention
+
+
+def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
+ return comfy.ldm.flux.layers.timestep_embedding(t, dim)
+
+
+def _lens_position_ids(
+ frame: int, height: int, width: int, text_seq_len: int,
+ scale_rope: bool = True, device=None,
+) -> torch.Tensor:
+ """Lens axial (frame, h, w) position ids for joint image + text sequence.
+
+ With ``scale_rope=True`` h/w are centered around 0 (negative + positive
+ halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
+ caller adds a batch dim for ``EmbedND``.
+ """
+ if scale_rope:
+ h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
+ torch.arange(0, height // 2, device=device)])
+ w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
+ torch.arange(0, width // 2, device=device)])
+ text_start = max(height // 2, width // 2)
+ else:
+ h_pos = torch.arange(height, device=device)
+ w_pos = torch.arange(width, device=device)
+ text_start = max(height, width)
+
+ f_pos = torch.arange(frame, device=device)
+ img_ids = torch.zeros(frame, height, width, 3, device=device)
+ img_ids[..., 0] = f_pos[:, None, None]
+ img_ids[..., 1] = h_pos[None, :, None]
+ img_ids[..., 2] = w_pos[None, None, :]
+ img_ids = img_ids.reshape(-1, 3)
+
+ # Text positions replicate across all 3 axes (matches original packing).
+ txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
+ txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
+
+ return torch.cat([img_ids, txt_ids], dim=0)
+
+
+class _TimestepEmbedder(nn.Module):
+ def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
+ super().__init__()
+ self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
+ self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.linear_1(x)
+ x = F.silu(x)
+ return self.linear_2(x)
+
+
+class LensTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
+ super().__init__()
+ self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)
+
+ def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
+ proj = _lens_time_proj(timestep, 256)
+ return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
+
+
+class GateMLP(nn.Module):
+ """SwiGLU MLP."""
+
+ def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
+ super().__init__()
+ self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
+ self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
+ self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
+
+ def forward(self, x):
+ return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))
+
+
+class LensJointAttention(nn.Module):
+ """Joint image+text attention with fused QKV per stream."""
+
+ def __init__(
+ self,
+ query_dim: int,
+ added_kv_proj_dim: int,
+ dim_head: int = 64,
+ heads: int = 8,
+ out_dim: Optional[int] = None,
+ eps: float = 1e-5,
+ dtype=None,
+ device=None,
+ operations=None,
+ ) -> None:
+ super().__init__()
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.heads = self.inner_dim // dim_head
+ self.dim_head = dim_head
+ self.out_dim = out_dim if out_dim is not None else query_dim
+
+ self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+ self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+ self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+ self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
+
+ self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
+ self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
+
+ # ModuleList([Linear, Identity]) for state-dict key compatibility.
+ self.to_out = nn.ModuleList([
+ operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
+ nn.Identity(),
+ ])
+ self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ transformer_options: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ bsz, seq_img, _ = hidden_states.shape
+ seq_txt = encoder_hidden_states.shape[1]
+
+ # image stream
+ img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
+ img_q, img_k, img_v = img_qkv.unbind(dim=2)
+ img_q = self.norm_q(img_q)
+ img_k = self.norm_k(img_k)
+ del img_qkv
+
+ # text stream
+ txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
+ txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
+ txt_q = self.norm_added_q(txt_q)
+ txt_k = self.norm_added_k(txt_k)
+
+ # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
+ q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
+ del img_q, txt_q
+ k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
+ del img_k, txt_k
+ v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
+ del img_v, txt_v
+
+ q, k = apply_rope(q, k, freqs_cis)
+
+ if attention_mask is not None:
+ expected = (bsz, 1, 1, seq_img + seq_txt)
+ if attention_mask.shape != expected:
+ raise ValueError(
+ f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
+ )
+ attention_mask = attention_mask.to(q.dtype)
+
+ out = optimized_attention(
+ q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
+ transformer_options=transformer_options,
+ )
+
+ img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
+ txt_out = self.to_add_out(out[:, seq_img:, :])
+ return img_out, txt_out
+
+
+class LensTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ eps: float = 1e-6,
+ rms_norm: bool = True,
+ dtype=None,
+ device=None,
+ operations=None,
+ ) -> None:
+ super().__init__()
+
+ self.attn = LensJointAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ eps=1e-5,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+
+ if rms_norm:
+ NormCls = operations.RMSNorm
+ norm_kwargs = {}
+ else:
+ NormCls = operations.LayerNorm
+ norm_kwargs = {"elementwise_affine": False}
+
+ mlp_hidden = int(dim / 3 * 8)
+
+ # Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
+ self.img_mod = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
+ )
+ self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
+ self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
+ self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
+
+ self.txt_mod = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
+ )
+ self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
+ self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
+ self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
+
+ @staticmethod
+ def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ shift, scale, gate = mod_params.chunk(3, dim=-1)
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ transformer_options: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
+ txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
+
+ img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+ txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
+
+ img_attn, txt_attn = self.attn(
+ hidden_states=img_modulated,
+ encoder_hidden_states=txt_modulated,
+ freqs_cis=freqs_cis,
+ attention_mask=attention_mask,
+ transformer_options=transformer_options,
+ )
+
+ hidden_states = hidden_states + img_gate1 * img_attn
+ encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
+
+ img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
+ hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
+
+ txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
+ encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
+
+ return encoder_hidden_states, hidden_states
+
+
+class _AdaLayerNormContinuousNoAffine(nn.Module):
+ """AdaLayerNormContinuous(elementwise_affine=False).
+
+ The reference uses ``scale, shift = chunk(2)`` (scale first) — opposite
+ to Flux's ``LastLayer``.
+ """
+
+ def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
+ dtype=None, device=None, operations=None) -> None:
+ super().__init__()
+ self.linear = operations.Linear(
+ conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
+ )
+ self.eps = eps
+ self.embedding_dim = embedding_dim
+
+ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
+ emb = self.linear(F.silu(conditioning))
+ scale, shift = torch.chunk(emb, 2, dim=-1)
+ x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class LensTransformer2DModel(nn.Module):
+ """Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""
+
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 128,
+ out_channels: Optional[int] = 32,
+ num_layers: int = 48,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 24,
+ enc_hidden_dim: int = 2880,
+ axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
+ rms_norm: bool = True,
+ multi_layer_encoder_feature: bool = True,
+ selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
+ image_model=None, # unused; accepted for detection-side configs.
+ dtype=None,
+ device=None,
+ operations=None,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels if out_channels is not None else in_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+ self.multi_layer_encoder_feature = multi_layer_encoder_feature
+ self.selected_layer_index = list(selected_layer_index)
+ self.dtype = dtype
+
+ self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
+ self.time_text_embed = LensTimestepProjEmbeddings(
+ embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
+ )
+
+ if self.multi_layer_encoder_feature:
+ self.txt_norm = nn.ModuleList(
+ [operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
+ for _ in self.selected_layer_index]
+ )
+ self.txt_in = operations.Linear(
+ enc_hidden_dim * len(self.selected_layer_index),
+ self.inner_dim, bias=True, dtype=dtype, device=device,
+ )
+ else:
+ self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
+ self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
+
+ self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
+
+ self.transformer_blocks = nn.ModuleList([
+ LensTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ eps=1e-6,
+ rms_norm=rms_norm,
+ dtype=dtype, device=device, operations=operations,
+ )
+ for _ in range(num_layers)
+ ])
+
+ self.norm_out = _AdaLayerNormContinuousNoAffine(
+ self.inner_dim, self.inner_dim, eps=1e-6,
+ dtype=dtype, device=device, operations=operations,
+ )
+ self.proj_out = operations.Linear(
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
+ dtype=dtype, device=device,
+ )
+
+ def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
+ transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
+ if transformer_options is None:
+ transformer_options = {}
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward, self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
+ ).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
+
+ def _forward(
+ self,
+ x: torch.Tensor,
+ timestep: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ transformer_options: Optional[Dict[str, Any]] = None,
+ control: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
+ if transformer_options is None:
+ transformer_options = {}
+ transformer_options = transformer_options.copy()
+ patches = transformer_options.get("patches", {})
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+
+ B, C, h, w = x.shape
+ hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
+
+ if self.multi_layer_encoder_feature:
+ L = len(self.selected_layer_index)
+ enc_dim = context.shape[-1] // L
+ encoder_hidden_states = list(
+ context.reshape(B, -1, L, enc_dim).unbind(dim=2)
+ )
+ text_seq_len = encoder_hidden_states[0].shape[1]
+ else:
+ encoder_hidden_states = context
+ text_seq_len = context.shape[1]
+
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (B, text_seq_len), dtype=torch.bool, device=x.device
+ )
+
+ img_len = h * w
+ joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
+
+ hidden_states = self.img_in(hidden_states)
+ timestep = timestep.to(hidden_states.dtype)
+
+ if self.multi_layer_encoder_feature:
+ normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
+ encoder_hidden_states = torch.cat(normed, dim=-1)
+ else:
+ encoder_hidden_states = self.txt_norm(encoder_hidden_states)
+ encoder_hidden_states = self.txt_in(encoder_hidden_states)
+
+ if "post_input" in patches:
+ for p in patches["post_input"]:
+ out = p({
+ "img": hidden_states,
+ "txt": encoder_hidden_states,
+ "transformer_options": transformer_options,
+ })
+ hidden_states = out["img"]
+ encoder_hidden_states = out["txt"]
+
+ temb = self.time_text_embed(timestep, hidden_states)
+ ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
+ freqs_cis = self.pos_embed(ids)
+
+ transformer_options["total_blocks"] = len(self.transformer_blocks)
+ transformer_options["block_type"] = "double"
+ for i, block in enumerate(self.transformer_blocks):
+ transformer_options["block_index"] = i
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["txt"], out["img"] = block(
+ hidden_states=args["img"],
+ encoder_hidden_states=args["txt"],
+ temb=args["vec"],
+ freqs_cis=args["pe"],
+ attention_mask=args.get("attn_mask"),
+ transformer_options=args.get("transformer_options"),
+ )
+ return out
+ out = blocks_replace[("double_block", i)](
+ {
+ "img": hidden_states,
+ "txt": encoder_hidden_states,
+ "vec": temb,
+ "pe": freqs_cis,
+ "attn_mask": joint_mask,
+ "transformer_options": transformer_options,
+ },
+ {"original_block": block_wrap},
+ )
+ encoder_hidden_states = out["txt"]
+ hidden_states = out["img"]
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ freqs_cis=freqs_cis,
+ attention_mask=joint_mask,
+ transformer_options=transformer_options,
+ )
+
+ if "double_block" in patches:
+ for p in patches["double_block"]:
+ out = p({
+ "img": hidden_states,
+ "txt": encoder_hidden_states,
+ "x": x,
+ "block_index": i,
+ "transformer_options": transformer_options,
+ })
+ hidden_states = out["img"]
+ encoder_hidden_states = out["txt"]
+
+ if control is not None:
+ control_i = control.get("input")
+ if control_i is not None and i < len(control_i):
+ add = control_i[i]
+ if add is not None:
+ hidden_states[:, :add.shape[1]] += add
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ out = self.proj_out(hidden_states)
+ return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
+
+ @staticmethod
+ def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
+ if text_mask.dtype != torch.bool:
+ text_mask = text_mask.bool()
+ bsz = text_mask.shape[0]
+ img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
+ joint = torch.cat([img_ones, text_mask], dim=1)
+ additive = torch.zeros_like(joint, dtype=torch.float32)
+ additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
+ return additive[:, None, None, :]
diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index bc09fb77e..ef9938465 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -767,25 +767,25 @@ class LTXAVModel(LTXVModel):
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
- timestep.max().expand_as(a_timestep_flat),
+ a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
- a_timestep.max().expand_as(timestep_flat),
+ timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
- a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
+ a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
- timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
+ timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
index b556b128f..58b67d45a 100644
--- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
import torch
from torch import nn
from torch.nn import functional as F
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index 998122c85..5975015e2 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
import threading
import torch
from torch import nn
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index 9e432d5c0..d0ee97d33 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -1,5 +1,4 @@
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
-from __future__ import annotations
from typing import List, Optional, Tuple
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index 0dc8fe789..9ab3c463c 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
- def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
super().__init__()
if output_size is None:
output_size = hidden_size
@@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
+ self.max_period = max_period
def forward(self, t, dtype, **kwargs):
- t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb
diff --git a/comfy/ldm/moge/geometry.py b/comfy/ldm/moge/geometry.py
index 7fdc97871..d1a1e445f 100644
--- a/comfy/ldm/moge/geometry.py
+++ b/comfy/ldm/moge/geometry.py
@@ -1,6 +1,5 @@
"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
-from __future__ import annotations
from typing import Optional, Tuple
diff --git a/comfy/ldm/moge/model.py b/comfy/ldm/moge/model.py
index 6876c4af2..1695626bc 100644
--- a/comfy/ldm/moge/model.py
+++ b/comfy/ldm/moge/model.py
@@ -4,7 +4,6 @@ V1: DINOv2 backbone + multi-output head (points, mask).
V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
"""
-from __future__ import annotations
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple, Union
diff --git a/comfy/ldm/moge/modules.py b/comfy/ldm/moge/modules.py
index 235a59212..f6443d65a 100644
--- a/comfy/ldm/moge/modules.py
+++ b/comfy/ldm/moge/modules.py
@@ -1,6 +1,5 @@
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
-from __future__ import annotations
from typing import List, Optional, Sequence, Tuple, Union
diff --git a/comfy/ldm/moge/panorama.py b/comfy/ldm/moge/panorama.py
index de53ebe68..18d0cb665 100644
--- a/comfy/ldm/moge/panorama.py
+++ b/comfy/ldm/moge/panorama.py
@@ -6,7 +6,6 @@ equirect distance map via a multi-scale Poisson + gradient sparse solve.
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
"""
-from __future__ import annotations
from typing import Callable, List, Optional, Tuple
diff --git a/comfy/ldm/pixeldit/model.py b/comfy/ldm/pixeldit/model.py
new file mode 100644
index 000000000..b044b9b29
--- /dev/null
+++ b/comfy/ldm/pixeldit/model.py
@@ -0,0 +1,239 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import comfy.ldm.common_dit
+import comfy.patcher_extension
+from comfy.ldm.flux.math import apply_rope, rope
+from comfy.ldm.hidream.model import FeedForwardSwiGLU
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
+
+from .modules import (
+ FinalLayer,
+ PatchTokenEmbedder,
+ PiTBlock,
+ PixelTokenEmbedder,
+ apply_adaln_,
+ precompute_freqs_cis_2d,
+)
+
+
+class MMDiTJointAttention(nn.Module):
+ """Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
+
+ RoPE is applied to each stream before concatenation so each stream uses its own
+ 2D/1D positional encoding. Concat order is [text, image] (text first).
+ """
+ def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ assert dim % num_heads == 0
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+
+ self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+ self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+
+ self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
+ self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
+ self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
+ self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
+
+ self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
+ self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
+
+ def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
+ B, Nx, _ = x.shape
+ _, Ny, _ = y.shape
+ H = self.num_heads
+ D = self.head_dim
+
+ qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
+ qx, kx, vx = qkv_x.unbind(0)
+ qx = self.q_norm_x(qx)
+ kx = self.k_norm_x(kx)
+
+ qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
+ qy, ky, vy = qkv_y.unbind(0)
+ qy = self.q_norm_y(qy)
+ ky = self.k_norm_y(ky)
+
+ qx, kx = apply_rope(qx, kx, pos_img[None, None])
+ if pos_txt is not None:
+ qy, ky = apply_rope(qy, ky, pos_txt[None, None])
+
+ q_joint = torch.cat([qy, qx], dim=2)
+ k_joint = torch.cat([ky, kx], dim=2)
+ v_joint = torch.cat([vy, vx], dim=2)
+
+ out_joint = optimized_attention(
+ q_joint, k_joint, v_joint, H,
+ mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
+ transformer_options=transformer_options,
+ )
+
+ out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
+ out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
+
+ return self.proj_x(out_x), self.proj_y(out_y)
+
+
+class MMDiTBlockT2I(nn.Module):
+ def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
+ self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
+ self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
+ self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
+ self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
+ self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
+ self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
+ self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
+
+ def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
+ shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
+ shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
+
+ x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
+ y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
+ attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
+ x = torch.addcmul(x, gate_msa_x, attn_x)
+ y = torch.addcmul(y, gate_msa_y, attn_y)
+
+ x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
+ y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
+ return x, y
+
+
+class PixDiT_T2I(nn.Module):
+ """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
+ (also runs at 512px when fed the appropriate latent size and flow_shift).
+
+ Forward:
+ x: [B, 3, H, W] pixel-space input (no VAE)
+ timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
+ context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
+ Returns flow-matching velocity [B, 3, H, W].
+ """
+ def __init__(
+ self,
+ in_channels=3,
+ num_groups=24,
+ hidden_size=1536,
+ pixel_hidden_size=16,
+ pixel_attn_hidden_size=1152,
+ pixel_num_groups=16,
+ patch_depth=14,
+ pixel_depth=2,
+ patch_size=16,
+ txt_embed_dim=2304,
+ txt_max_length=300,
+ use_text_rope=True,
+ text_rope_theta=10000.0,
+ image_model=None,
+ dtype=None,
+ device=None,
+ operations=None,
+ pixel_mlp_chunks=2,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.hidden_size = hidden_size
+ self.num_groups = num_groups
+ self.patch_depth = patch_depth
+ self.pixel_depth = pixel_depth
+ self.patch_size = patch_size
+ self.pixel_hidden_size = pixel_hidden_size
+ self.pixel_attn_hidden_size = pixel_attn_hidden_size
+ self.pixel_num_groups = pixel_num_groups
+ self.txt_embed_dim = txt_embed_dim
+ self.txt_max_length = txt_max_length
+ self.use_text_rope = use_text_rope
+ self.text_rope_theta = text_rope_theta
+
+ self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
+ self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
+ self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
+ self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
+ self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
+
+ self.patch_blocks = nn.ModuleList([
+ MMDiTBlockT2I(self.hidden_size, self.num_groups,
+ dtype=dtype, device=device, operations=operations)
+ for _ in range(self.patch_depth)
+ ])
+ self.pixel_blocks = nn.ModuleList([
+ PiTBlock(
+ self.pixel_hidden_size,
+ self.hidden_size,
+ patch_size=self.patch_size,
+ num_heads=self.num_groups,
+ attn_hidden_size=self.pixel_attn_hidden_size,
+ attn_num_heads=self.pixel_num_groups,
+ dtype=dtype, device=device, operations=operations,
+ mlp_chunks=pixel_mlp_chunks,
+ )
+ for _ in range(self.pixel_depth)
+ ])
+
+ self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+
+ def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
+ return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)
+
+ def _fetch_text_pos(self, length, device, dtype):
+ return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)
+
+ def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
+ ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
+
+ def _pre_patch_block(self, s, i, **kwargs):
+ """Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
+ return s
+
+ def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
+ H_orig, W_orig = x.shape[2], x.shape[3]
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+ B, _, H, W = x.shape
+ Hs = H // self.patch_size
+ Ws = W // self.patch_size
+ L = Hs * Ws
+
+ pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
+ x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
+
+ t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
+
+ if context is None or context.dim() != 3:
+ raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
+ Ltxt = min(context.shape[1], self.txt_max_length)
+ y = context[:, :Ltxt, :]
+ y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
+ y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
+
+ condition = F.silu(t_emb)
+ pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
+
+ s = self.s_embedder(x_patches)
+ for i, blk in enumerate(self.patch_blocks):
+ s = self._pre_patch_block(s, i, **kwargs)
+ s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
+ s = F.silu(t_emb + s)
+
+ s_cond = s.view(B * L, self.hidden_size)
+ x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
+ for blk in self.pixel_blocks:
+ x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
+
+ x_pixels = self.final_layer(x_pixels)
+ C_out = self.out_channels
+ P2 = self.patch_size * self.patch_size
+ x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
+ out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
+ return out[:, :, :H_orig, :W_orig]
diff --git a/comfy/ldm/pixeldit/modules.py b/comfy/ldm/pixeldit/modules.py
new file mode 100644
index 000000000..4b1e538c7
--- /dev/null
+++ b/comfy/ldm/pixeldit/modules.py
@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+
+from comfy.ldm.flux.math import apply_rope, rope
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
+
+
+def apply_adaln_(x, shift, scale):
+ return x.addcmul_(x, scale).add_(shift)
+
+
+def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
+ ref_grid_h=None, ref_grid_w=None,
+ scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
+ device=None, dtype=torch.float32, **kwargs):
+ """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
+
+ rope_options:
+ scale_x / scale_y multiply the position range (RoPE extrapolation).
+ shift_x / shift_y offset the position origin (tiled / regional inference).
+ With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
+ (rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
+ Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
+ Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
+ """
+ dim_axis = dim // 2
+ if ref_grid_h is not None and dim_axis > 2:
+ h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
+ w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
+ else:
+ h_ntk = w_ntk = 1.0
+
+ x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
+ y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
+ y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
+ x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
+ y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
+ out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
+ return out.to(dtype=dtype)
+
+
+def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
+ """Standard 2D sin/cos absolute positional embedding (ViT-style).
+
+ first half encodes W-coordinates, second half H.
+ """
+ assert embed_dim % 4 == 0
+ grid_h = torch.arange(height, dtype=torch.float32, device=device)
+ grid_w = torch.arange(width, dtype=torch.float32, device=device)
+ grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
+ emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
+ emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
+ return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
+
+
+class RotaryAttention(nn.Module):
+ """Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
+ def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
+ super().__init__()
+ assert dim % num_heads == 0
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+ self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
+ self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
+ self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+
+ def forward(self, x, pos, mask=None, transformer_options={}):
+ B, N, C = x.shape
+ H = self.num_heads
+ D = self.head_dim
+ qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0)
+ q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
+ x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
+ return self.proj(x)
+
+
+class FinalLayer(nn.Module):
+ def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
+ self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
+
+ def forward(self, x):
+ return self.linear(self.norm(x))
+
+
+class PatchTokenEmbedder(nn.Module):
+ """Linear projection used both for patchified-image tokens and text-feature tokens."""
+ def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
+ self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
+
+ def forward(self, x):
+ return self.norm(self.proj(x))
+
+
+class PixelTokenEmbedder(nn.Module):
+ """Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
+ def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_size_output = hidden_size_output
+ self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
+
+ def forward(self, inputs, patch_size):
+ B, _, H, W = inputs.shape
+ Hs, Ws = H // patch_size, W // patch_size
+ P2 = patch_size * patch_size
+ x = inputs.permute(0, 2, 3, 1).contiguous()
+ x = self.proj(x)
+ pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
+ x = x + pos_full.unsqueeze(0)
+ x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
+ return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)
+
+
+class PiTBlock(nn.Module):
+ """Pixel-level transformer block.
+
+ Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
+ runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
+ Conditioning is per-pixel adaLN from the patch-level features.
+ """
+ def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
+ attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
+ super().__init__()
+ self.pixel_dim = pixel_hidden_size
+ self.context_dim = patch_hidden_size
+ self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
+ self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
+ assert self.attn_dim % self.num_heads == 0
+
+ p2 = patch_size * patch_size
+ self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
+ self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
+
+ self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
+ self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
+ self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
+ self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
+
+ self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
+ self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
+
+ self._rope_fn = precompute_freqs_cis_2d
+ self.mlp_chunks = max(1, int(mlp_chunks))
+
+ def _fetch_pos(self, height, width, device, dtype, **rope_opts):
+ return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)
+
+ def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
+ BL, P2, _ = x.shape
+ Hs, Ws = image_height // patch_size, image_width // patch_size
+ L = Hs * Ws
+ B = BL // L
+
+ # Attention path uses only msa params; compute, use, free before mlp params allocate.
+ msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
+ shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
+
+ x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
+ x_flat = x_norm.view(BL, P2 * self.pixel_dim)
+
+ x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
+ pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
+ attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
+ attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
+ attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
+ x = torch.addcmul(x, gate_msa, attn_exp)
+ del msa_params, shift_msa, scale_msa, gate_msa
+
+ mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
+ shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
+ gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
+ mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
+ del mlp_params, shift_mlp, scale_mlp
+
+ # MLP in chunks since the peak memory usage is huge here
+ chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
+ for s in range(0, BL, chunk_size):
+ e = min(s + chunk_size, BL)
+ x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
+ return x
diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py
new file mode 100644
index 000000000..0ad4b7ce8
--- /dev/null
+++ b/comfy/ldm/pixeldit/pid.py
@@ -0,0 +1,226 @@
+"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
+directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
+body + LQ projection branch injected before each MMDiT patch block.
+"""
+
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .model import PixDiT_T2I
+from .modules import precompute_freqs_cis_2d
+
+
+class SigmaAwareGatePerTokenPerDim(nn.Module):
+ """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
+
+ Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
+ """
+
+ def __init__(self, dim: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
+ self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
+
+ def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
+ # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
+ log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
+ sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
+ gate = torch.sigmoid(content_logit + sigma_offset)
+ return x + (gate * lq).to(x.dtype)
+
+
+class ResBlock(nn.Module):
+ """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
+
+ def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.block = nn.Sequential(
+ operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
+ nn.SiLU(),
+ operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
+ operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
+ nn.SiLU(),
+ operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x + self.block(x)
+
+
+class LQProjection2D(nn.Module):
+ """LQ latent -> per-block patch-aligned features for controlnet-style injection."""
+
+ def __init__(
+ self,
+ latent_channels: int,
+ hidden_dim: int = 512,
+ out_dim: int = 1536,
+ patch_size: int = 16,
+ sr_scale: int = 4,
+ latent_spatial_down_factor: int = 8,
+ num_res_blocks: int = 4,
+ num_outputs: int = 7,
+ interval: int = 2,
+ dtype=None, device=None, operations=None,
+ ):
+ super().__init__()
+ self.latent_channels = latent_channels
+ self.hidden_dim = hidden_dim
+ self.out_dim = out_dim
+ self.patch_size = patch_size
+ self.sr_scale = sr_scale
+ self.latent_spatial_down_factor = latent_spatial_down_factor
+ self.num_outputs = num_outputs
+ self.interval = interval
+
+ z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
+ self.z_to_patch_ratio = z_to_patch_ratio
+ if z_to_patch_ratio >= 1:
+ self.latent_fold_factor = 0
+ latent_proj_in_ch = latent_channels
+ else:
+ fold_factor = int(1 / z_to_patch_ratio)
+ assert fold_factor * z_to_patch_ratio == 1.0
+ self.latent_fold_factor = fold_factor
+ latent_proj_in_ch = latent_channels * fold_factor * fold_factor
+
+ layers = [
+ operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
+ nn.SiLU(),
+ operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
+ ]
+ for _ in range(num_res_blocks):
+ layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
+ self.latent_proj = nn.Sequential(*layers)
+
+ self.output_heads = nn.ModuleList(
+ [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
+ )
+ self.gate_modules = nn.ModuleList(
+ [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
+ for _ in range(num_outputs)]
+ )
+
+ def is_gate_active(self, block_idx: int) -> bool:
+ return block_idx % self.interval == 0
+
+ def output_index(self, block_idx: int) -> int:
+ return block_idx // self.interval
+
+ def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
+ return self.gate_modules[out_idx](x, lq_feature, sigma)
+
+ def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
+ B, z_dim = lq_latent.shape[:2]
+ if self.z_to_patch_ratio >= 1:
+ if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
+ z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
+ else:
+ z_aligned = lq_latent
+ else:
+ f = self.latent_fold_factor
+ zH_expected, zW_expected = pH * f, pW * f
+ if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
+ lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
+ z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
+ z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
+ return self.latent_proj(z_aligned)
+
+ def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
+ feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
+ B, C, H, W = feat.shape
+ tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
+ return [head(tokens) for head in self.output_heads]
+
+
+class PidNet(PixDiT_T2I):
+ """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
+
+ def __init__(
+ self,
+ lq_latent_channels: int = 16,
+ lq_hidden_dim: int = 512,
+ lq_num_res_blocks: int = 4,
+ lq_interval: int = 2,
+ sr_scale: int = 4,
+ latent_spatial_down_factor: int = 8,
+ rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
+ rope_ref_w: int = 1024,
+ image_model=None,
+ dtype=None, device=None, operations=None,
+ **pixdit_kwargs,
+ ):
+ super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
+
+ self.rope_ref_grid_h = rope_ref_h // self.patch_size
+ self.rope_ref_grid_w = rope_ref_w // self.patch_size
+
+ # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
+ def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
+ return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)
+ for blk in self.pixel_blocks:
+ blk._rope_fn = _pit_rope_fn
+
+ num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
+ self.lq_proj = LQProjection2D(
+ latent_channels=lq_latent_channels,
+ hidden_dim=lq_hidden_dim,
+ out_dim=self.hidden_size,
+ patch_size=self.patch_size,
+ sr_scale=sr_scale,
+ latent_spatial_down_factor=latent_spatial_down_factor,
+ num_res_blocks=lq_num_res_blocks,
+ num_outputs=num_lq_outputs,
+ interval=lq_interval,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+
+ def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
+ return precompute_freqs_cis_2d(
+ self.hidden_size // self.num_groups,
+ height, width,
+ ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
+ device=device, dtype=dtype, **rope_opts,
+ )
+
+ def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
+ if not self.lq_proj.is_gate_active(i):
+ return s
+ out_idx = self.lq_proj.output_index(i)
+ if out_idx >= len(pid_lq_features):
+ return s
+ return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
+
+ def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
+ if lq_latent is None:
+ raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
+ expected_c = self.lq_proj.latent_channels
+ if lq_latent.shape[1] != expected_c:
+ raise ValueError(
+ f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
+ f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
+ )
+ B = x.shape[0]
+ Hs = x.shape[2] // self.patch_size
+ Ws = x.shape[3] // self.patch_size
+
+ degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
+ if degrade_sigma.numel() == 1 and B > 1:
+ degrade_sigma = degrade_sigma.expand(B).contiguous()
+
+ lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
+
+ return super()._forward(
+ x, timesteps,
+ context=context, attention_mask=attention_mask,
+ transformer_options=transformer_options,
+ pid_lq_features=lq_features,
+ pid_degrade_sigma=degrade_sigma,
+ **kwargs,
+ )
diff --git a/comfy/lora.py b/comfy/lora.py
index c0e8b865c..4e0ea29e0 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -16,7 +16,6 @@
along with this program. If not, see .
"""
-from __future__ import annotations
import comfy.memory_management
import comfy.utils
import comfy.model_management
diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index c43f0c4a2..962addb27 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -1,6 +1,5 @@
import math
import ctypes
-import threading
import dataclasses
import torch
from typing import NamedTuple
@@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
- thread_id: int
+ lock: object
offset: int
size: int
@@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
- or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
@@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0
- hostbuf.read_file_slice(file_obj, info.offset, info.size,
- offset=destination.data_ptr() - hostbuf.get_raw_address(),
- stream=stream_ptr,
- device_ptr=device_ptr,
- device=None if destination2 is None else destination2.device.index)
+ with info.lock:
+ hostbuf.read_file_slice(file_obj, info.offset, info.size,
+ offset=destination.data_ptr() - hostbuf.get_raw_address(),
+ stream=stream_ptr,
+ device_ptr=device_ptr,
+ device=None if destination2 is None else destination2.device.index)
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
- file_obj.seek(info.offset)
- done = 0
- while done < info.size:
- try:
- n = file_obj.readinto(view[done:])
- except OSError:
- return False
- if n <= 0:
- return False
- done += n
+ with info.lock:
+ file_obj.seek(info.offset)
+ done = 0
+ while done < info.size:
+ try:
+ n = file_obj.readinto(view[done:])
+ except OSError:
+ return False
+ if n <= 0:
+ return False
+ done += n
return True
finally:
view.release()
diff --git a/comfy/model_base.py b/comfy/model_base.py
index d81f13c69..e55808633 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
+import comfy.ldm.lens.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
@@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
+import comfy.ldm.pixeldit.model
+import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
@@ -1058,6 +1061,27 @@ class Flux2(Flux):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
+
+class Lens(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+ super().__init__(
+ model_config, model_type, device=device,
+ unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
+ )
+
+ def encode_adm(self, **kwargs):
+ return None # Lens has no pooled/ADM conditioning.
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ return out
+
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1375,6 +1399,36 @@ class ZImagePixelSpace(Lumina2):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
+
+class PixelDiTT2I(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device,
+ unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
+ return out
+
+
+class PiD(PixelDiTT2I):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ BaseModel.__init__(self, model_config, model_type, device=device,
+ unet_model=comfy.ldm.pixeldit.pid.PidNet)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ lq_latent = kwargs.get("lq_latent", None)
+ if lq_latent is not None:
+ out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
+ degrade_sigma = kwargs.get("degrade_sigma", None)
+ if degrade_sigma is not None:
+ out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
+ return out
+
+
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 70b4df8b3..f0db7d388 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
+ # PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
+ _lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
+ if _lq_w_key in state_dict_keys:
+ in_ch = int(state_dict[_lq_w_key].shape[1])
+ _gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
+ num_gates = len({k[len(_gate_prefix):].split('.')[0]
+ for k in state_dict_keys if k.startswith(_gate_prefix)})
+ dit_config = {"image_model": "pid",
+ "lq_latent_channels": in_ch,
+ "latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
+ if num_gates > 0:
+ dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
+ return dit_config
+
+ if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
+ return {"image_model": "pixeldit_t2i"}
+
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
@@ -755,6 +772,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
+ if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
+ and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
+ img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
+ proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
+ multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
+ if multi_layer:
+ enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
+ # Indices are TE-side; the DiT just consumes L layers in order.
+ selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
+ else:
+ enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
+ selected_layer_index = (0,)
+
+ return {
+ "image_model": "lens",
+ "in_channels": img_in_w.shape[1],
+ "out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
+ "num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
+ "num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
+ "enc_hidden_dim": enc_hidden_dim,
+ "multi_layer_encoder_feature": multi_layer,
+ "selected_layer_index": selected_layer_index,
+ }
+
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"
diff --git a/comfy/model_management.py b/comfy/model_management.py
index cd8772d3a..b01c4d7fa 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see .
"""
+from __future__ import annotations
import psutil
import logging
@@ -27,13 +28,18 @@ import platform
import weakref
import gc
import os
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from comfy.model_patcher import ModelPatcher
+
+
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@@ -204,6 +210,107 @@ def get_torch_device():
else:
return torch.device(torch.cuda.current_device())
+def get_all_torch_devices(exclude_current=False):
+ global cpu_state
+ devices = []
+ if cpu_state == CPUState.GPU:
+ # NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
+ # without the AMD arm, single-GPU ROCm users get an empty list
+ # which silently turns unload_all_models() into a no-op.
+ if is_nvidia() or is_amd():
+ for i in range(torch.cuda.device_count()):
+ devices.append(torch.device("cuda", i))
+ elif is_intel_xpu():
+ for i in range(torch.xpu.device_count()):
+ devices.append(torch.device("xpu", i))
+ elif is_ascend_npu():
+ for i in range(torch.npu.device_count()):
+ devices.append(torch.device("npu", i))
+ elif is_mlu():
+ for i in range(torch.mlu.device_count()):
+ devices.append(torch.device("mlu", i))
+ else:
+ # Fallback for unhandled GPU backends (e.g. DirectML): at least
+ # report the current device so callers like unload_all_models()
+ # do not silently no-op.
+ devices.append(get_torch_device())
+ else:
+ devices.append(get_torch_device())
+ if exclude_current:
+ current = get_torch_device()
+ if current in devices:
+ devices.remove(current)
+ return devices
+
+def get_gpu_device_options():
+ """Return list of device option strings for node widgets.
+
+ Always includes "default" and "cpu". When multiple GPUs are present,
+ adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
+ """
+ options = ["default", "cpu"]
+ devices = get_all_torch_devices()
+ if len(devices) > 1:
+ for i in range(len(devices)):
+ options.append(f"gpu:{i}")
+ return options
+
+def get_gpu_device_options_no_cpu():
+ """Variant of get_gpu_device_options that omits "cpu".
+
+ Intended for components like the VAE selector where running on CPU
+ is impractical and should not be offered as a choice.
+ """
+ return [o for o in get_gpu_device_options() if o != "cpu"]
+
+def resolve_gpu_device_option(option: str):
+ """Resolve a device option string to a torch.device.
+
+ Returns None for "default" (let the caller use its normal default).
+ Returns torch.device("cpu") for "cpu".
+ For "gpu:N", returns the Nth torch device. Returns None if the
+ index is out of range, the option string is malformed, or
+ unrecognized (callers are expected to log their own context-rich
+ message before falling back to the default device).
+ """
+ if option is None or option == "default":
+ return None
+ if option == "cpu":
+ return torch.device("cpu")
+ if option.startswith("gpu:"):
+ try:
+ idx = int(option[4:])
+ except ValueError:
+ return None
+ devices = get_all_torch_devices()
+ if 0 <= idx < len(devices):
+ return devices[idx]
+ return None
+
+@contextmanager
+def cuda_device_context(device):
+ """Context manager that sets torch.cuda.current_device to match *device*.
+
+ Used when running operations on a non-default CUDA device so that custom
+ CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
+ device index. The previous device is restored on exit.
+
+ No-op when *device* is not CUDA, has no explicit index, or already matches
+ the current device.
+ """
+ prev = None
+ if device.type == "cuda" and device.index is not None:
+ prev = torch.cuda.current_device()
+ if prev != device.index:
+ torch.cuda.set_device(device)
+ else:
+ prev = None
+ try:
+ yield
+ finally:
+ if prev is not None:
+ torch.cuda.set_device(prev)
+
def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled
if dev is None:
@@ -492,9 +599,13 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
logging.warning("Could not pick default device.")
+try:
+ for device in get_all_torch_devices(exclude_current=True):
+ logging.info("Device: {}".format(get_torch_device_name(device)))
+except:
+ pass
-
-current_loaded_models = []
+current_loaded_models: list[LoadedModel] = []
DIRTY_MMAPS = set()
@@ -554,7 +665,7 @@ def ensure_pin_registerable(size, evict_active=False):
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
- def __init__(self, model):
+ def __init__(self, model: ModelPatcher):
self._set_model(model)
self.device = model.load_device
self.real_model = None
@@ -562,7 +673,7 @@ class LoadedModel:
self.model_finalizer = None
self._patcher_finalizer = None
- def _set_model(self, model):
+ def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model)
if model.parent is not None:
self._parent_model = weakref.ref(model.parent)
@@ -573,6 +684,7 @@ class LoadedModel:
model = self._parent_model()
if model is not None:
self._set_model(model)
+ self.device = model.load_device
@property
def model(self):
@@ -1848,7 +1960,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect()
def unload_all_models():
- free_memory(1e30, get_torch_device())
+ for device in get_all_torch_devices():
+ free_memory(1e30, device)
+
+def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
+ 'Unload only model and its clones - primarily for multigpu cloning purposes.'
+ initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
+ additional_models = []
+ if unload_additional_models:
+ additional_models = model.get_nested_additional_models()
+ keep_loaded = []
+ for loaded_model in initial_keep_loaded:
+ if loaded_model.model is not None:
+ if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
+ continue
+ # check additional models if they are a match
+ skip = False
+ for add_model in additional_models:
+ if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
+ skip = True
+ break
+ if skip:
+ continue
+ keep_loaded.append(loaded_model)
+ if not all_devices:
+ free_memory(1e30, get_torch_device(), keep_loaded)
+ else:
+ for device in get_all_torch_devices():
+ free_memory(1e30, device, keep_loaded)
def debug_memory_summary():
if is_amd() or is_nvidia():
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b44b99e4a..00a15fa63 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -78,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
-def create_hook_patches_clone(orig_hook_patches):
+def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {}
for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
+ if copy_tuples:
+ for i in range(len(new_hook_patches[hook_ref][k])):
+ new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches
def wipe_lowvram_weight(m):
@@ -329,7 +332,10 @@ class ModelPatcher:
self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
- self.cached_patcher_init: tuple[Callable, tuple] | None = None
+ self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
+ self.is_multigpu_base_clone = False
+ self.clone_base_uuid = uuid.uuid4()
+
if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0
@@ -366,7 +372,8 @@ class ModelPatcher:
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
- aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
+ aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
+ aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self):
@@ -380,6 +387,8 @@ class ModelPatcher:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
+ if len(self.cached_patcher_init) > 2:
+ temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override()
if model_override is None:
model_override = self.get_clone_model_override()
@@ -438,19 +447,113 @@ class ModelPatcher:
n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
+ n.is_multigpu_base_clone = self.is_multigpu_base_clone
+ n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n)
return n
+ def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
+ logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
+ if self.cached_patcher_init is None:
+ raise RuntimeError(
+ f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
+ "the loader that produced this model does not support multigpu "
+ "(cached_patcher_init is not initialized). Use a core loader "
+ "(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
+ "or have the custom loader register a cached_patcher_init factory."
+ )
+ comfy.model_management.unload_model_and_clones(self)
+ # Produce a freshly-loaded patcher from the loader factory so the multigpu
+ # clone owns its own untainted model weights (rather than relying on
+ # copy.deepcopy of an already-patched/already-loaded module).
+ temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
+ if len(self.cached_patcher_init) > 2:
+ temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
+ # Override clone()'s normal "share self.model + share backup containers" with
+ # the pristine model from temp_model_patcher plus empty backup containers --
+ # the fresh model has no patches applied, so any deepcopy of self's stale
+ # backup/object_patches_backup/pinned would just propagate dead state that
+ # no longer corresponds to anything in n.model.
+ model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
+ n = self.clone(model_override=model_override)
+ # clone() copies hook_backup by reference from self; reset since model is pristine.
+ n.hook_backup = {}
+ # set load device, if present
+ if new_load_device is not None:
+ n.load_device = new_load_device
+ # Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
+ # has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
+ # __init__ only registered its own (default) load_device.
+ if hasattr(n, "register_load_device"):
+ n.register_load_device(n.load_device)
+ # multigpu clone should not have multigpu additional_models entry
+ n.remove_additional_models("multigpu")
+ # multigpu_clone all stored additional_models; make sure circular references are properly handled
+ if models_cache is None:
+ models_cache = {}
+ for key, model_list in n.additional_models.items():
+ for i in range(len(model_list)):
+ add_model = n.additional_models[key][i]
+ if add_model.clone_base_uuid not in models_cache:
+ models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
+ n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
+ for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
+ callback(self, n)
+ return n
+
+ def match_multigpu_clones(self):
+ multigpu_models = self.get_additional_models_with_key("multigpu")
+ if len(multigpu_models) > 0:
+ new_multigpu_models = []
+ for mm in multigpu_models:
+ # clone main model, but bring over relevant props from existing multigpu clone
+ n = self.clone()
+ n.load_device = mm.load_device
+ n.backup = mm.backup
+ n.object_patches_backup = mm.object_patches_backup
+ n.hook_backup = mm.hook_backup
+ n.model = mm.model
+ n.is_multigpu_base_clone = mm.is_multigpu_base_clone
+ n.remove_additional_models("multigpu")
+ orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
+ n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
+ # figure out which additional models are not present in multigpu clone
+ models_cache = {}
+ for mm_add_model in mm.get_additional_models():
+ models_cache[mm_add_model.clone_base_uuid] = mm_add_model
+ remove_models_uuids = set(list(models_cache.keys()))
+ for key, model_list in orig_additional_models.items():
+ for orig_add_model in model_list:
+ if orig_add_model.clone_base_uuid not in models_cache:
+ models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
+ existing_list = n.get_additional_models_with_key(key)
+ existing_list.append(models_cache[orig_add_model.clone_base_uuid])
+ n.set_additional_models(key, existing_list)
+ if orig_add_model.clone_base_uuid in remove_models_uuids:
+ remove_models_uuids.remove(orig_add_model.clone_base_uuid)
+ # remove duplicate additional models
+ for key, model_list in n.additional_models.items():
+ new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
+ n.set_additional_models(key, new_model_list)
+ for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
+ callback(self, n)
+ new_multigpu_models.append(n)
+ self.set_additional_models("multigpu", new_multigpu_models)
+
def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model:
return True
return False
- def clone_has_same_weights(self, clone: 'ModelPatcher'):
- if not self.is_clone(clone):
- return False
+ def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
+ if allow_multigpu:
+ if self.clone_base_uuid != clone.clone_base_uuid:
+ return False
+ else:
+ if not self.is_clone(clone):
+ return False
if self.current_hooks != clone.current_hooks:
return False
@@ -1232,7 +1335,7 @@ class ModelPatcher:
return self.additional_models.get(key, [])
def get_additional_models(self):
- all_models = []
+ all_models: list[ModelPatcher] = []
for models in self.additional_models.values():
all_models.extend(models)
return all_models
@@ -1286,9 +1389,18 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self)
- def prepare_state(self, timestep):
+ def prepare_state(self, timestep, model_options):
+ ignore_multigpu = model_options.get("ignore_multigpu", False)
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
- callback(self, timestep)
+ callback(self, timestep, model_options)
+ if not ignore_multigpu and "multigpu_clones" in model_options:
+ model_options["ignore_multigpu"] = True
+ try:
+ for p in model_options["multigpu_clones"].values():
+ p: ModelPatcher
+ p.prepare_state(timestep, model_options)
+ finally:
+ model_options.pop("ignore_multigpu", None)
def restore_hook_patches(self):
if self.hook_patches_backup is not None:
@@ -1301,12 +1413,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0]
reset_current_hooks = False
+ multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling
if changed:
+ # cache changed for multigpu usage
+ if "multigpu_clones" in model_options:
+ if multigpu_kf_changed_cache is None:
+ multigpu_kf_changed_cache = []
+ multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
@@ -1318,6 +1436,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
+ if "multigpu_clones" in model_options:
+ for p in model_options["multigpu_clones"].values():
+ p: ModelPatcher
+ p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
+
+ def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
+ 'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
+ if kf_changed_cache is None:
+ return
+ reset_current_hooks = False
+ # reset current_hooks if contains hook that changed
+ for hook in kf_changed_cache:
+ if self.current_hooks is not None:
+ for current_hook in self.current_hooks.hooks:
+ if current_hook == hook:
+ reset_current_hooks = True
+ break
+ for cached_group in list(self.cached_hook_patches.keys()):
+ if cached_group.contains(hook):
+ self.cached_hook_patches.pop(cached_group)
+ if reset_current_hooks:
+ self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None):
@@ -1566,16 +1706,27 @@ class ModelPatcherDynamic(ModelPatcher):
self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
- if self.load_device not in self.model.dynamic_pins:
- self.model.dynamic_pins[self.load_device] = {
+ self.register_load_device(self.load_device)
+ self.non_dynamic_delegate_model = None
+ assert load_device is not None
+
+ def register_load_device(self, device):
+ """Ensure dynamic_pins has an entry for *device*.
+
+ Called from __init__ and also from any code that retargets an
+ already-constructed patcher to a new load_device (e.g. the
+ Select{Model,CLIP,VAE}Device selector nodes); without this entry
+ partially_unload_ram() raises KeyError when it tries to read the
+ per-device pin state.
+ """
+ if device not in self.model.dynamic_pins:
+ self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False,
"failed": False,
"active": False,
}
- self.non_dynamic_delegate_model = None
- assert load_device is not None
def is_dynamic(self):
return True
diff --git a/comfy/multigpu.py b/comfy/multigpu.py
new file mode 100644
index 000000000..e7f5b3d6f
--- /dev/null
+++ b/comfy/multigpu.py
@@ -0,0 +1,248 @@
+from __future__ import annotations
+import queue
+import threading
+import torch
+import logging
+
+from collections import namedtuple
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from comfy.model_patcher import ModelPatcher
+import comfy.utils
+import comfy.patcher_extension
+import comfy.model_management
+
+
+class MultiGPUThreadPool:
+ """Persistent thread pool for multi-GPU work distribution.
+
+ Maintains one worker thread per extra GPU device. Each thread calls
+ torch.cuda.set_device() once at startup so that compiled kernel caches
+ (inductor/triton) stay warm across diffusion steps.
+ """
+
+ def __init__(self, devices: list[torch.device]):
+ self._workers: list[threading.Thread] = []
+ self._work_queues: dict[torch.device, queue.Queue] = {}
+ self._result_queues: dict[torch.device, queue.Queue] = {}
+
+ for device in devices:
+ wq = queue.Queue()
+ rq = queue.Queue()
+ self._work_queues[device] = wq
+ self._result_queues[device] = rq
+ t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
+ t.start()
+ self._workers.append(t)
+
+ def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
+ try:
+ torch.cuda.set_device(device)
+ except Exception as e:
+ logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
+ while True:
+ item = work_q.get()
+ if item is None:
+ return
+ result_q.put((None, e))
+ return
+ while True:
+ item = work_q.get()
+ if item is None:
+ break
+ fn, args, kwargs = item
+ try:
+ result = fn(*args, **kwargs)
+ result_q.put((result, None))
+ except Exception as e:
+ result_q.put((None, e))
+
+ def submit(self, device: torch.device, fn, *args, **kwargs):
+ self._work_queues[device].put((fn, args, kwargs))
+
+ def get_result(self, device: torch.device):
+ return self._result_queues[device].get()
+
+ @property
+ def devices(self) -> list[torch.device]:
+ return list(self._work_queues.keys())
+
+ def shutdown(self):
+ for wq in self._work_queues.values():
+ wq.put(None) # sentinel
+ for t in self._workers:
+ t.join(timeout=5.0)
+
+
+class GPUOptions:
+ def __init__(self, device_index: int, relative_speed: float):
+ self.device_index = device_index
+ self.relative_speed = relative_speed
+
+ def clone(self):
+ return GPUOptions(self.device_index, self.relative_speed)
+
+ def create_dict(self):
+ return {
+ "relative_speed": self.relative_speed
+ }
+
+class GPUOptionsGroup:
+ def __init__(self):
+ self.options: dict[int, GPUOptions] = {}
+
+ def add(self, info: GPUOptions):
+ self.options[info.device_index] = info
+
+ def clone(self):
+ c = GPUOptionsGroup()
+ for opt in self.options.values():
+ c.add(opt)
+ return c
+
+ def register(self, model: ModelPatcher):
+ opts_dict = {}
+ # get devices that are valid for this model
+ devices: list[torch.device] = [model.load_device]
+ for extra_model in model.get_additional_models_with_key("multigpu"):
+ extra_model: ModelPatcher
+ devices.append(extra_model.load_device)
+ # create dictionary with actual device mapped to its GPUOptions
+ device_opts_list: list[GPUOptions] = []
+ for device in devices:
+ device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
+ opts_dict[device] = device_opts.create_dict()
+ device_opts_list.append(device_opts)
+ # make relative_speed relative to 1.0
+ min_speed = min([x.relative_speed for x in device_opts_list])
+ for value in opts_dict.values():
+ value['relative_speed'] /= min_speed
+ model.model_options['multigpu_options'] = opts_dict
+
+
+def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
+ 'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
+ model = model.clone()
+ # check if multigpu is already prepared - get the load devices from them if possible to exclude
+ skip_devices = set()
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ if len(multigpu_models) > 0:
+ for mm in multigpu_models:
+ skip_devices.add(mm.load_device)
+ skip_devices = list(skip_devices)
+
+ # Exclude the primary model's actual device, not the global current device:
+ # after SelectModelDevice(gpu:N) the primary may not live on the process's
+ # current CUDA device, and excluding the wrong device picks bad extras.
+ all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False)
+ full_extra_devices = [d for d in all_devices if d != model.load_device]
+ limit_extra_devices = full_extra_devices[:max_gpus-1]
+ extra_devices = limit_extra_devices.copy()
+ # exclude skipped devices
+ for skip in skip_devices:
+ if skip in extra_devices:
+ extra_devices.remove(skip)
+ # create new deepclones
+ if len(extra_devices) > 0:
+ for device in extra_devices:
+ device_patcher = None
+ if reuse_loaded:
+ # Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice
+ # patcher on the same device shares clone_base_uuid but has
+ # is_multigpu_base_clone=False, which would later be filtered out by
+ # prepare_model_patcher_multigpu_clones() and silently shrink the
+ # work split back to one GPU.
+ loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
+ for lm in loaded_models:
+ if lm.model is None:
+ continue
+ if lm.load_device != device:
+ continue
+ if lm.clone_base_uuid != model.clone_base_uuid:
+ continue
+ if not getattr(lm, "is_multigpu_base_clone", False):
+ continue
+ device_patcher = lm.clone()
+ logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}")
+ break
+ if device_patcher is None:
+ device_patcher = model.deepclone_multigpu(new_load_device=device)
+ # Always flag the clone; whether reused or freshly deepcloned, it must
+ # advertise itself as a MultiGPU base clone so the cond scheduler picks
+ # it up in prepare_model_patcher_multigpu_clones().
+ device_patcher.is_multigpu_base_clone = True
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ multigpu_models.append(device_patcher)
+ model.set_additional_models("multigpu", multigpu_models)
+ model.match_multigpu_clones()
+ if gpu_options is None:
+ gpu_options = GPUOptionsGroup()
+ gpu_options.register(model)
+ else:
+ logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
+ # only keep model clones that don't go 'past' the intended max_gpu count;
+ # this prunes any inherited multigpu clones whose load_device is no longer allowed
+ # when max_gpus is lowered between runs.
+ allowed_devices = set(limit_extra_devices)
+ allowed_devices.add(model.load_device)
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
+ if len(new_multigpu_models) != len(multigpu_models):
+ model.set_additional_models("multigpu", new_multigpu_models)
+ model.match_multigpu_clones()
+ return model
+
+
+LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
+def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
+ 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
+ opts_dict = model_options['multigpu_options']
+ devices = list(model_options['multigpu_clones'].keys())
+ speed_per_device = []
+ work_per_device = []
+ # get sum of each device's relative_speed
+ total_speed = 0.0
+ for opts in opts_dict.values():
+ total_speed += opts['relative_speed']
+ # get relative work for each device;
+ # obtained by w = (W*r)/R
+ for device in devices:
+ relative_speed = opts_dict[device]['relative_speed']
+ relative_work = (total_work*relative_speed) / total_speed
+ speed_per_device.append(relative_speed)
+ work_per_device.append(relative_work)
+ # relative work must be expressed in whole numbers, but likely is a decimal;
+ # perform rounding while maintaining total sum equal to total work (sum of relative works)
+ work_per_device = round_preserved(work_per_device)
+ dict_work_per_device = {}
+ for device, relative_work in zip(devices, work_per_device):
+ dict_work_per_device[device] = relative_work
+ if not return_idle_time:
+ return LoadBalance(dict_work_per_device, None)
+ # divide relative work by relative speed to get estimated completion time of said work by each device;
+ # time here is relative and does not correspond to real-world units
+ completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
+ # calculate relative time spent by the devices waiting on each other after their work is completed
+ idle_time = abs(min(completion_time) - max(completion_time))
+ # if need to compare work idle time, need to normalize to a common total work
+ if work_normalized:
+ idle_time *= (work_normalized/total_work)
+
+ return LoadBalance(dict_work_per_device, idle_time)
+
+def round_preserved(values: list[float]):
+ 'Round all values in a list, preserving the combined sum of values.'
+ # get floor of values; casting to int does it too
+ floored = [int(x) for x in values]
+ total_floored = sum(floored)
+ # get remainder to distribute
+ remainder = round(sum(values)) - total_floored
+ # pair values with fractional portions
+ fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
+ # sort by fractional part in descending order
+ fractional.sort(key=lambda x: x[1], reverse=True)
+ # distribute the remainder
+ for i in range(remainder):
+ index = fractional[i][0]
+ floored[index] += 1
+ return floored
diff --git a/comfy/ops.py b/comfy/ops.py
index 9bcd6c900..56445be8d 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -18,6 +18,7 @@
import torch
import logging
+import contextlib
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
@@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
+# Quantized-weight module helpers
+
+def _quantized_apply(module, fn, recurse=True):
+ """Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
+ if recurse:
+ for child in module.children():
+ child._apply(fn)
+ for key, param in module._parameters.items():
+ if param is None:
+ continue
+ p = fn(param)
+ if (not torch.is_inference_mode_enabled()) and p.is_inference():
+ p = p.clone()
+ module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
+ for key, buf in module._buffers.items():
+ if buf is not None:
+ module._buffers[key] = fn(buf)
+ return module
+
+
+def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
+ """Shared _load_from_state_dict body for quantized-weight modules.
+
+ Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
+ or Parameter-wrapped QuantizedTensor, then calls super_load and strips
+ consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
+ and disabled formats from module._disabled_formats.
+ """
+ device = module.factory_kwargs["device"]
+ compute_dtype = module.factory_kwargs["dtype"]
+ disabled_formats = module._disabled_formats
+ layer_name = prefix.rstrip('.')
+
+ weight = state_dict.pop(f"{prefix}weight", None)
+ if weight is None:
+ logging.warning(f"Missing weight for layer {layer_name}")
+ module.weight = None
+ return
+ manually_loaded_keys = [f"{prefix}weight"]
+
+ def pop_scale(name, dtype=None):
+ key = f"{prefix}{name}"
+ v = state_dict.pop(key, None)
+ if v is not None:
+ v = v.to(device=device)
+ if dtype is not None:
+ v = v.view(dtype=dtype)
+ manually_loaded_keys.append(key)
+ return v
+
+ layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+ if layer_conf is not None:
+ layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+ if layer_conf is None:
+ module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
+ else:
+ module.quant_format = layer_conf.get("format", None)
+ module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
+ if not module._full_precision_mm:
+ module._full_precision_mm = module._full_precision_mm_config
+ if module.quant_format in disabled_formats:
+ module._full_precision_mm = True
+ if module.quant_format is None:
+ raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+ qconfig = QUANT_ALGOS[module.quant_format]
+ module.layout_type = qconfig["comfy_tensor_layout"]
+ layout_cls = get_layout_class(module.layout_type)
+
+ # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
+ if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
+ scales = {"scale": pop_scale("weight_scale")}
+ elif module.quant_format == "mxfp8":
+ bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
+ if bs is None:
+ raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
+ scales = {"scale": bs}
+ elif module.quant_format == "nvfp4":
+ ts = pop_scale("weight_scale_2")
+ bs = pop_scale("weight_scale", torch.float8_e4m3fn)
+ if ts is None or bs is None:
+ raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
+ scales = {"scale": ts, "block_scale": bs}
+ else:
+ raise ValueError(f"Unsupported quantization format: {module.quant_format}")
+
+ params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
+ module.weight = torch.nn.Parameter(
+ QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
+ requires_grad=False,
+ )
+
+ if load_extra_params:
+ for param_name in qconfig["parameters"]:
+ if param_name in {"weight_scale", "weight_scale_2"}:
+ continue
+ param_key = f"{prefix}{param_name}"
+ _v = state_dict.pop(param_key, None)
+ if _v is None:
+ continue
+ module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+ manually_loaded_keys.append(param_key)
+
+ super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ for key in manually_loaded_keys:
+ if key in missing_keys:
+ missing_keys.remove(key)
+
+
+def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
+ """Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
+ extra_quant_params names attributes written as additional top-level keys."""
+ if not hasattr(module, 'weight'):
+ logging.warning(f"Warning: state dict on uninitialized op {prefix}")
+ return sd
+ bias = getattr(module, 'bias', None)
+ if bias is not None:
+ sd[f"{prefix}bias"] = bias
+ if module.weight is None:
+ return sd
+ if isinstance(module.weight, QuantizedTensor):
+ sd.update(module.weight.state_dict(f"{prefix}weight"))
+ quant_conf = {"format": module.quant_format}
+ if getattr(module, '_full_precision_mm_config', False):
+ quant_conf["full_precision_matrix_mult"] = True
+ if extra_quant_conf:
+ quant_conf.update(extra_quant_conf)
+ sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
+ for name in extra_quant_params:
+ value = getattr(module, name, None)
+ if value is not None:
+ sd[f"{prefix}{name}"] = value
+ else:
+ sd[f"{prefix}weight"] = module.weight
+ return sd
+
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
@@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp):
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool = True,
- device=None,
- dtype=None,
- ) -> None:
+ _disabled_formats = disabled
+
+ def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
- # self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
+ self._orig_shape = (out_features, in_features)
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
@@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def reset_parameters(self):
return None
- def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
- key = f"{prefix}{param_name}"
- value = state_dict.pop(key, None)
- if value is not None:
- value = value.to(device=device)
- if dtype is not None:
- value = value.view(dtype=dtype)
- manually_loaded_keys.append(key)
- return value
-
- def _load_from_state_dict(self, state_dict, prefix, local_metadata,
- strict, missing_keys, unexpected_keys, error_msgs):
-
- device = self.factory_kwargs["device"]
- layer_name = prefix.rstrip('.')
- weight_key = f"{prefix}weight"
- weight = state_dict.pop(weight_key, None)
- if weight is None:
- logging.warning(f"Missing weight for layer {layer_name}")
- self.weight = None
- return
-
- manually_loaded_keys = [weight_key]
-
- layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
- if layer_conf is not None:
- layer_conf = json.loads(layer_conf.numpy().tobytes())
-
- if layer_conf is None:
- self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
- else:
- self.quant_format = layer_conf.get("format", None)
- self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
- if not self._full_precision_mm:
- self._full_precision_mm = self._full_precision_mm_config
-
- if self.quant_format in MixedPrecisionOps._disabled:
- self._full_precision_mm = True
-
- if self.quant_format is None:
- raise ValueError(f"Unknown quantization format for layer {layer_name}")
-
- qconfig = QUANT_ALGOS[self.quant_format]
- self.layout_type = qconfig["comfy_tensor_layout"]
- layout_cls = get_layout_class(self.layout_type)
-
- # Load format-specific parameters
- if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
- # FP8: single tensor scale
- scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
-
- params = layout_cls.Params(
- scale=scale,
- orig_dtype=MixedPrecisionOps._compute_dtype,
- orig_shape=(self.out_features, self.in_features),
- )
-
- elif self.quant_format == "mxfp8":
- # MXFP8: E8M0 block scales stored as uint8 in safetensors
- block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
- dtype=torch.uint8)
-
- if block_scale is None:
- raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
-
- block_scale = block_scale.view(torch.float8_e8m0fnu)
-
- params = layout_cls.Params(
- scale=block_scale,
- orig_dtype=MixedPrecisionOps._compute_dtype,
- orig_shape=(self.out_features, self.in_features),
- )
-
- elif self.quant_format == "nvfp4":
- # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
- tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
- block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
- dtype=torch.float8_e4m3fn)
-
- if tensor_scale is None or block_scale is None:
- raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
-
- params = layout_cls.Params(
- scale=tensor_scale,
- block_scale=block_scale,
- orig_dtype=MixedPrecisionOps._compute_dtype,
- orig_shape=(self.out_features, self.in_features),
- )
- else:
- raise ValueError(f"Unsupported quantization format: {self.quant_format}")
-
- self.weight = torch.nn.Parameter(
- QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
- requires_grad=False
- )
-
- for param_name in qconfig["parameters"]:
- if param_name in {"weight_scale", "weight_scale_2"}:
- continue # Already handled above
-
- param_key = f"{prefix}{param_name}"
- _v = state_dict.pop(param_key, None)
- if _v is None:
- continue
- self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
- manually_loaded_keys.append(param_key)
-
- super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
- for key in manually_loaded_keys:
- if key in missing_keys:
- missing_keys.remove(key)
+ def _load_from_state_dict(self, *args):
+ _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
- if destination is not None:
- sd = destination
- else:
- sd = {}
-
- if not hasattr(self, 'weight'):
- logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
- return sd
-
- if self.bias is not None:
- sd["{}bias".format(prefix)] = self.bias
-
- if self.weight is None:
- return sd
-
- if isinstance(self.weight, QuantizedTensor):
- sd_out = self.weight.state_dict("{}weight".format(prefix))
- for k in sd_out:
- sd[k] = sd_out[k]
-
- quant_conf = {"format": self.quant_format}
- if self._full_precision_mm_config:
- quant_conf["full_precision_matrix_mult"] = True
- sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
-
- input_scale = getattr(self, 'input_scale', None)
- if input_scale is not None:
- sd["{}input_scale".format(prefix)] = input_scale
- else:
- sd["{}weight".format(prefix)] = self.weight
- return sd
+ sd = destination if destination is not None else {}
+ return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
- if recurse:
- for module in self.children():
- module._apply(fn)
+ return _quantized_apply(self, fn, recurse)
- for key, param in self._parameters.items():
- if param is None:
- continue
- p = fn(param)
- if (not torch.is_inference_mode_enabled()) and p.is_inference():
- p = p.clone()
- self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
- for key, buf in self._buffers.items():
- if buf is not None:
- self._buffers[key] = fn(buf)
- return self
+ class MoEExperts(torch.nn.Module, CastWeightBiasOp):
+ """Container for E quantized expert weights, indexed via expert_weight(i).
+
+ The bank lives on self.weight as a single 3D tensor — either a
+ compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
+ with leading expert dim.
+
+ State-dict layout matches mixed_precision_ops.Linear with a leading
+ expert dim:
+ {prefix}.weight quant data (storage_t), leading dim = E
+ {prefix}.weight_scale block / per-tensor scale
+ {prefix}.weight_scale_2 [E] or scalar NVFP4 only
+ {prefix}.bias [E, out_features] optional, compute_dtype
+ {prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
+
+ Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
+ """
+
+ _disabled_formats = disabled
+
+ def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
+ super().__init__()
+ self.num_experts = num_experts
+ self.in_features = in_features
+ self.out_features = out_features
+ self._orig_shape = (num_experts, out_features, in_features)
+ self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+ if bias:
+ self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
+ else:
+ self.register_parameter("bias", None)
+
+ # Populated by _load_from_state_dict:
+ self.weight = None
+ self.quant_format = None
+ self.layout_type = None
+ self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+ self._full_precision_mm_config = False
+ self._resident_bank = None
+
+ def reset_parameters(self):
+ return None
+
+ def _apply(self, fn, recurse=True):
+ return _quantized_apply(self, fn, recurse)
+
+ def _load_from_state_dict(self, *args):
+ _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)
+
+ def expert_weight(self, i: int):
+ """Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
+ if isinstance(self.weight, QuantizedTensor):
+ return self._expert_qt_from(self.weight, i)
+ return self.weight[i]
+
+ @contextlib.contextmanager
+ def bank_resident(self, input):
+ """Cast the whole bank once; expert_linear inside reuses the cast.
+ Not re-entrant — do not nest calls on the same instance.
+ """
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ self._resident_bank = (weight, bias)
+ try:
+ yield self
+ finally:
+ self._resident_bank = None
+ uncast_bias_weight(self, weight, bias, offload_stream)
+
+ def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
+ """Linear against expert i's weight (with optional bias)."""
+ resident = getattr(self, "_resident_bank", None)
+ if resident is not None:
+ weight, bias = resident
+ return self._expert_linear_impl(input, weight, bias, i)
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ try:
+ return self._expert_linear_impl(input, weight, bias, i)
+ finally:
+ uncast_bias_weight(self, weight, bias, offload_stream)
+
+ def _expert_linear_impl(self, input, weight, bias, i):
+ if isinstance(weight, QuantizedTensor):
+ qw = self._expert_qt_from(weight, i)
+ else:
+ qw = weight[i]
+ b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
+
+ if isinstance(qw, QuantizedTensor):
+ use_fast = (
+ not self._full_precision_mm
+ and qw.layout_cls.supports_fast_matmul()
+ and input.dim() == 2
+ )
+ if use_fast:
+ qin = QuantizedTensor.from_float(input, self.layout_type)
+ return torch.nn.functional.linear(qin, qw, b)
+ out = input @ qw.dequantize().t()
+ return out + b if b is not None else out
+ return torch.nn.functional.linear(input, qw, b)
+
+ def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
+ """Build a per-expert QuantizedTensor by indexing into a resident bank."""
+ params = weight._params
+ kwargs = {
+ "scale": params.scale[i] if params.scale.dim() else params.scale,
+ "orig_dtype": params.orig_dtype,
+ "orig_shape": (self.out_features, self.in_features),
+ }
+ if hasattr(params, "block_scale"): # NVFP4
+ kwargs["block_scale"] = params.block_scale[i]
+ return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))
+
+ def state_dict(self, *args, destination=None, prefix="", **kwargs):
+ sd = destination if destination is not None else {}
+ return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
class Embedding(manual_cast.Embedding):
- def _load_from_state_dict(self, state_dict, prefix, local_metadata,
- strict, missing_keys, unexpected_keys, error_msgs):
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
@@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
- quant_format = layer_conf.get("format", None) if layer_conf is not None else None
- if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
+ quant_format = layer_conf.get("format") if layer_conf is not None else None
+ manually_loaded_keys = []
+
+ if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key)
- manually_loaded_keys = [weight_key]
+ manually_loaded_keys.append(weight_key)
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
@@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
+ elif layer_conf is not None:
+ # Unsupported format — restore the marker so it round-trips; fall through to default load.
+ state_dict[f"{prefix}comfy_quant"] = torch.tensor(
+ list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
- super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
- for k in manually_loaded_keys:
- if k in missing_keys:
- missing_keys.remove(k)
- else:
- if layer_conf is not None:
- state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
- super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ for k in manually_loaded_keys:
+ if k in missing_keys:
+ missing_keys.remove(k)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
- if destination is not None:
- sd = destination
- else:
- sd = {}
-
- if not hasattr(self, 'weight') or self.weight is None:
- return sd
-
- if isinstance(self.weight, QuantizedTensor):
- sd_out = self.weight.state_dict("{}weight".format(prefix))
- for k in sd_out:
- sd[k] = sd_out[k]
-
- quant_conf = {"format": self.quant_format}
- sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
- else:
- sd["{}weight".format(prefix)] = self.weight
- return sd
+ sd = destination if destination is not None else {}
+ return _quantized_weight_state_dict(self, sd, prefix)
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py
index 5ee4d5ee5..189ee84ca 100644
--- a/comfy/patcher_extension.py
+++ b/comfy/patcher_extension.py
@@ -1,8 +1,9 @@
-from __future__ import annotations
from typing import Callable
class CallbacksMP:
ON_CLONE = "on_clone"
+ ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
+ ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup"
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 3782fd2d5..bdce2f2d8 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -1,16 +1,18 @@
from __future__ import annotations
+import torch
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
+import comfy.model_patcher
import comfy.utils
import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
- from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
+ from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device):
@@ -119,6 +121,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
+def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
+ '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.'''
+ multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
+ if len(multigpu_models) == 0:
+ return
+ extra_devices = [x.load_device for x in multigpu_models]
+ # handle controlnets
+ controlnets: set[ControlBase] = set()
+ for k in conds:
+ for kk in conds[k]:
+ if 'control' in kk:
+ controlnets.add(kk['control'])
+ if len(controlnets) > 0:
+ # first, unload all controlnet clones
+ for cnet in list(controlnets):
+ cnet_models = cnet.get_models()
+ for cm in cnet_models:
+ comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
+
+ # next, make sure each controlnet has a deepclone for all relevant devices
+ for cnet in controlnets:
+ curr_cnet = cnet
+ while curr_cnet is not None:
+ for device in extra_devices:
+ if device not in curr_cnet.multigpu_clones:
+ curr_cnet.deepclone_multigpu(device, autoregister=True)
+ curr_cnet = curr_cnet.previous_controlnet
+ # since all device clones are now present, recreate the linked list for cloned cnets per device
+ for cnet in controlnets:
+ curr_cnet = cnet
+ while curr_cnet is not None:
+ prev_cnet = curr_cnet.previous_controlnet
+ for device in extra_devices:
+ device_cnet = curr_cnet.get_instance_for_device(device)
+ prev_device_cnet = None
+ if prev_cnet is not None:
+ prev_device_cnet = prev_cnet.get_instance_for_device(device)
+ device_cnet.set_previous_controlnet(prev_device_cnet)
+ curr_cnet = prev_cnet
+ # potentially handle gligen - since not widely used, ignored for now
+
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
@@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
- real_model: BaseModel = None
+ model.match_multigpu_clones()
+ preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory
minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
- real_model = model.model
+ real_model: BaseModel = model.model
return real_model, conds, models
@@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False)
return to_load_options
+
+def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
+ '''
+ In case multigpu acceleration is enabled, prep ModelPatchers for each device.
+ '''
+ multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
+ if len(multigpu_patchers) > 0:
+ multigpu_dict: dict[torch.device, ModelPatcher] = {}
+ multigpu_dict[model_patcher.load_device] = model_patcher
+ for x in multigpu_patchers:
+ x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
+ x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
+ multigpu_dict[x.load_device] = x
+ model_options["multigpu_clones"] = multigpu_dict
+ return multigpu_patchers
diff --git a/comfy/samplers.py b/comfy/samplers.py
index c5e36ff05..e31277f7b 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -1,7 +1,9 @@
from __future__ import annotations
+
+import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc
-from typing import TYPE_CHECKING, Callable, NamedTuple
+from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
@@ -16,6 +18,7 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
+import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
@@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
-def cond_cat(c_list):
+def cond_cat(c_list, device=None):
temp = {}
for x in c_list:
for k in x:
@@ -153,6 +156,8 @@ def cond_cat(c_list):
for k in temp:
conds = temp[k]
out[k] = conds[0].concat(conds[1:])
+ if device is not None and hasattr(out[k], 'to'):
+ out[k] = out[k].to(device)
return out
@@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
)
return executor.execute(model, conds, x_in, timestep, model_options)
-def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
+def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ # NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
+ # (hooked_to_run accumulation, memory-fit batching, per-chunk output
+ # aggregation) is duplicated there with per-device scheduling layered on top.
+ if 'multigpu_clones' in model_options:
+ return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = []
out_counts = []
# separate conds by matching hooks
@@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
- model.current_patcher.prepare_state(timestep)
+ model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items():
@@ -344,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds
+def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
+ # NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
+ # accumulation, memory-fit batching, and output aggregation, but adds a
+ # per-device scheduler, per-device patcher/control lookup, tensor .to(device)
+ # placement, and MultiGPUThreadPool dispatch around the inner loop.
+ out_conds = []
+ out_counts = []
+ # separate conds by matching hooks
+ hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
+ default_conds = []
+ has_default_conds = False
+
+ output_device = x_in.device
+
+ for i in range(len(conds)):
+ out_conds.append(torch.zeros_like(x_in))
+ out_counts.append(torch.ones_like(x_in) * 1e-37)
+
+ cond = conds[i]
+ default_c = []
+ if cond is not None:
+ for x in cond:
+ if 'default' in x:
+ default_c.append(x)
+ has_default_conds = True
+ continue
+ p = get_area_and_mult(x, x_in, timestep)
+ if p is None:
+ continue
+ if p.hooks is not None:
+ model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
+ hooked_to_run.setdefault(p.hooks, list())
+ hooked_to_run[p.hooks] += [(p, i)]
+ default_conds.append(default_c)
+
+ if has_default_conds:
+ finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
+
+ model.current_patcher.prepare_state(timestep, model_options)
+
+ devices = list(model_options['multigpu_clones'].keys())
+ device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
+ # Track conds currently scheduled per device; single source of truth for capacity checks.
+ device_load: dict[torch.device, int] = {d: 0 for d in devices}
+
+ total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
+ conds_per_device = max(1, math.ceil(total_conds / len(devices)))
+
+ def next_available_device(start: int) -> tuple[int, torch.device]:
+ """Return (index, device) for the next device with remaining capacity, starting at `start`.
+
+ Scans at most len(devices) positions, so this always terminates. Raises if no device
+ has remaining capacity, which would indicate a bug in conds_per_device accounting.
+ """
+ for offset in range(len(devices)):
+ i = (start + offset) % len(devices)
+ if device_load[devices[i]] < conds_per_device:
+ return i, devices[i]
+ raise RuntimeError(
+ f"MultiGPU scheduler: all {len(devices)} devices at capacity "
+ f"({conds_per_device}) but conds remain to schedule"
+ )
+
+ # run every hooked_to_run separately
+ index_device = 0
+ for hooks, to_run in hooked_to_run.items():
+ while len(to_run) > 0:
+ index_device, current_device = next_available_device(index_device)
+ remaining_capacity = conds_per_device - device_load[current_device]
+
+ first = to_run[0]
+ first_shape = first[0][0].shape
+ # collect candidate indices that can be concatenated with `first`, up to remaining capacity
+ to_batch_temp = []
+ for x in range(len(to_run)):
+ if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
+ to_batch_temp += [x]
+
+ to_batch_temp.reverse()
+ to_batch = to_batch_temp[:1]
+
+ free_memory = comfy.model_management.get_free_memory(current_device)
+ for i in range(1, len(to_batch_temp) + 1):
+ batch_amount = to_batch_temp[:len(to_batch_temp)//i]
+ input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
+ cond_shapes = collections.defaultdict(list)
+ for tt in batch_amount:
+ for k, v in to_run[tt][0].conditioning.items():
+ cond_shapes[k].append(v.size())
+ if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
+ to_batch = batch_amount
+ break
+
+ conds_to_batch = [to_run.pop(x) for x in to_batch]
+ device_load[current_device] += len(conds_to_batch)
+ device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
+
+ if device_load[current_device] >= conds_per_device:
+ index_device += 1
+
+ class thread_result(NamedTuple):
+ output: Any
+ mult: Any
+ area: Any
+ batch_chunks: int
+ cond_or_uncond: Any
+ error: Exception = None
+
+ def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
+ try:
+ # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
+ # we extend multigpu QA beyond CUDA. Unconditional call crashes on
+ # XPU/NPU/MPS/CPU/DirectML backends.
+ torch.cuda.set_device(device)
+ model_current: BaseModel = model_options["multigpu_clones"][device].model
+ # run every hooked_to_run separately
+ with torch.no_grad():
+ for hooks, to_batch in batch_tuple:
+ input_x = []
+ mult = []
+ c = []
+ cond_or_uncond = []
+ uuids = []
+ area = []
+ control: ControlBase = None
+ patches = None
+ for x in to_batch:
+ o = x
+ p = o[0]
+ input_x.append(p.input_x)
+ mult.append(p.mult)
+ c.append(p.conditioning)
+ area.append(p.area)
+ cond_or_uncond.append(o[1])
+ uuids.append(p.uuid)
+ control = p.control
+ patches = p.patches
+
+ batch_chunks = len(cond_or_uncond)
+ input_x = torch.cat(input_x).to(device)
+ c = cond_cat(c, device=device)
+ timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
+
+ transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
+ if 'transformer_options' in model_options:
+ transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
+ model_options['transformer_options'],
+ copy_dict1=False)
+
+ if patches is not None:
+ transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
+ transformer_options.get("patches", {}),
+ patches
+ )
+
+ transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+ transformer_options["uuids"] = uuids[:]
+ transformer_options["sigmas"] = timestep.to(device)
+ transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
+ transformer_options["multigpu_thread_device"] = device
+
+ cast_transformer_options(transformer_options, device=device)
+ c['transformer_options'] = transformer_options
+
+ if control is not None:
+ device_control = control.get_instance_for_device(device)
+ c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
+
+ if 'model_function_wrapper' in model_options:
+ output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
+ else:
+ output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
+ # TODO: non-NVIDIA support -- the `.to(output_device)` copies
+ # above are async on CUDA, so the main thread's aggregation
+ # could race with in-flight transfers. CUDA-only QA has not
+ # surfaced this in practice, but before extending multigpu
+ # beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
+ # here (guarded by `output_device.type == "cuda"`).
+ results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
+ except Exception as e:
+ results.append(thread_result(None, None, None, None, None, error=e))
+ raise
+
+
+ def _handle_batch_pooled(device, batch_tuple):
+ worker_results = []
+ _handle_batch(device, batch_tuple, worker_results)
+ return worker_results
+
+ results: list[thread_result] = []
+ thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
+
+ # Submit all GPU work to pool threads
+ pool_devices = []
+ for device, batch_tuple in device_batched_hooked_to_run.items():
+ if thread_pool is not None:
+ thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
+ pool_devices.append(device)
+ else:
+ # Fallback: no pool, run everything on main thread
+ _handle_batch(device, batch_tuple, results)
+
+ # Collect results from pool workers
+ for device in pool_devices:
+ worker_results, error = thread_pool.get_result(device)
+ if error is not None:
+ raise error
+ results.extend(worker_results)
+
+ for output, mult, area, batch_chunks, cond_or_uncond, error in results:
+ if error is not None:
+ raise error
+ for o in range(batch_chunks):
+ cond_index = cond_or_uncond[o]
+ a = area[o]
+ if a is None:
+ out_conds[cond_index] += output[o] * mult[o]
+ out_counts[cond_index] += mult[o]
+ else:
+ out_c = out_conds[cond_index]
+ out_cts = out_counts[cond_index]
+ dims = len(a) // 2
+ for i in range(dims):
+ out_c = out_c.narrow(i + 2, a[i + dims], a[i])
+ out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
+ out_c += output[o] * mult[o]
+ out_cts += mult[o]
+
+ for i in range(len(out_conds)):
+ out_conds[i] /= out_counts[i]
+
+ return out_conds
+
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@@ -642,12 +885,21 @@ def calculate_start_end_timesteps(model, conds):
def pre_run_control(model, conds):
s = model.model_sampling
+ # Per-device model lookup so multigpu control clones get the matching
+ # diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
+ device_models: dict = {}
+ patcher = getattr(model, "current_patcher", None)
+ if patcher is not None:
+ for p in patcher.get_additional_models_with_key("multigpu"):
+ device_models[p.load_device] = p.model
for t in range(len(conds)):
x = conds[t]
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)
+ for device, device_cnet in x['control'].multigpu_clones.items():
+ device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = []
@@ -890,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None)
if to_load_options is None:
return
+ cast_transformer_options(to_load_options, device, dtype)
+def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = []
if device is not None:
casts.append(device)
@@ -899,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing
if len(casts) == 0:
return
-
# try to call .to on patches
- if "patches" in to_load_options:
- patches = to_load_options["patches"]
+ if "patches" in transformer_options:
+ patches = transformer_options["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
for cast in casts:
patch_list[i] = patch_list[i].to(cast)
- if "patches_replace" in to_load_options:
- patches = to_load_options["patches_replace"]
+ if "patches_replace" in transformer_options:
+ patches = transformer_options["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
@@ -920,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks:
- if wc_name in to_load_options:
- wc: dict[str, list] = to_load_options[wc_name]
+ if wc_name in transformer_options:
+ wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values():
for wc_list in wc_dict.values():
for i in range(len(wc_list)):
@@ -929,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts:
wc_list[i] = wc_list[i].to(cast)
-
class CFGGuider:
def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher
@@ -984,16 +1236,32 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
- noise = noise.to(device=device, dtype=torch.float32)
- latent_image = latent_image.to(device=device, dtype=torch.float32)
- sigmas = sigmas.to(device)
- cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+ multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
- try:
- self.model_patcher.pre_run()
- output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
- finally:
- self.model_patcher.cleanup()
+ # Create persistent thread pool for all GPU devices (main + extras)
+ if multigpu_patchers:
+ extra_devices = [p.load_device for p in multigpu_patchers]
+ all_devices = [device] + extra_devices
+ self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
+
+ with comfy.model_management.cuda_device_context(device):
+ try:
+ noise = noise.to(device=device, dtype=torch.float32)
+ latent_image = latent_image.to(device=device, dtype=torch.float32)
+ sigmas = sigmas.to(device)
+ cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+
+ self.model_patcher.pre_run()
+ for multigpu_patcher in multigpu_patchers:
+ multigpu_patcher.pre_run()
+ output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
+ finally:
+ thread_pool = self.model_options.pop("multigpu_thread_pool", None)
+ if thread_pool is not None:
+ thread_pool.shutdown()
+ self.model_patcher.cleanup()
+ for multigpu_patcher in multigpu_patchers:
+ multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model
diff --git a/comfy/sd.py b/comfy/sd.py
index 7bd07ed3a..30b877b85 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
import json
import torch
from enum import Enum
@@ -50,6 +49,7 @@ import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
+import comfy.text_encoders.pixeldit
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
@@ -69,6 +69,7 @@ import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
+import comfy.text_encoders.gpt_oss
import comfy.model_patcher
import comfy.lora
@@ -335,41 +336,43 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
- self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
+ device = self.patcher.load_device
+ self.cond_stage_model.set_clip_options({"execution_device": device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes))
- for scheduled_opts in scheduled_keyframes:
- t_range = scheduled_opts[0]
- # don't bother encoding any conds outside of start_percent and end_percent bounds
- if "start_percent" in add_dict:
- if t_range[1] < add_dict["start_percent"]:
- continue
- if "end_percent" in add_dict:
- if t_range[0] > add_dict["end_percent"]:
- continue
- hooks_keyframes = scheduled_opts[1]
- for hook, keyframe in hooks_keyframes:
- hook.hook_keyframe._current_keyframe = keyframe
- # apply appropriate hooks with values that match new hook_keyframe
- self.patcher.patch_hooks(all_hooks)
- # perform encoding as normal
- o = self.cond_stage_model.encode_token_weights(tokens)
- cond, pooled = o[:2]
- pooled_dict = {"pooled_output": pooled}
- # add clip_start_percent and clip_end_percent in pooled
- pooled_dict["clip_start_percent"] = t_range[0]
- pooled_dict["clip_end_percent"] = t_range[1]
- # add/update any keys with the provided add_dict
- pooled_dict.update(add_dict)
- # add hooks stored on clip
- self.add_hooks_to_dict(pooled_dict)
- all_cond_pooled.append([cond, pooled_dict])
- if show_pbar:
- pbar.update(1)
- model_management.throw_exception_if_processing_interrupted()
+ with model_management.cuda_device_context(device):
+ for scheduled_opts in scheduled_keyframes:
+ t_range = scheduled_opts[0]
+ # don't bother encoding any conds outside of start_percent and end_percent bounds
+ if "start_percent" in add_dict:
+ if t_range[1] < add_dict["start_percent"]:
+ continue
+ if "end_percent" in add_dict:
+ if t_range[0] > add_dict["end_percent"]:
+ continue
+ hooks_keyframes = scheduled_opts[1]
+ for hook, keyframe in hooks_keyframes:
+ hook.hook_keyframe._current_keyframe = keyframe
+ # apply appropriate hooks with values that match new hook_keyframe
+ self.patcher.patch_hooks(all_hooks)
+ # perform encoding as normal
+ o = self.cond_stage_model.encode_token_weights(tokens)
+ cond, pooled = o[:2]
+ pooled_dict = {"pooled_output": pooled}
+ # add clip_start_percent and clip_end_percent in pooled
+ pooled_dict["clip_start_percent"] = t_range[0]
+ pooled_dict["clip_end_percent"] = t_range[1]
+ # add/update any keys with the provided add_dict
+ pooled_dict.update(add_dict)
+ # add hooks stored on clip
+ self.add_hooks_to_dict(pooled_dict)
+ all_cond_pooled.append([cond, pooled_dict])
+ if show_pbar:
+ pbar.update(1)
+ model_management.throw_exception_if_processing_interrupted()
all_hooks.reset()
return all_cond_pooled
@@ -383,8 +386,12 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens)
- self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
- o = self.cond_stage_model.encode_token_weights(tokens)
+ device = self.patcher.load_device
+ self.cond_stage_model.set_clip_options({"execution_device": device})
+
+ with model_management.cuda_device_context(device):
+ o = self.cond_stage_model.encode_token_weights(tokens)
+
cond, pooled = o[:2]
if return_dict:
out = {"cond": cond, "pooled_output": pooled}
@@ -446,9 +453,12 @@ class CLIP:
self.cond_stage_model.reset_clip_options()
self.load_model(tokens)
+ device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"layer": None})
- self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
- return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
+ self.cond_stage_model.set_clip_options({"execution_device": device})
+
+ with model_management.cuda_device_context(device):
+ return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
@@ -1026,50 +1036,52 @@ class VAE:
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
- try:
- memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
- model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
- free_memory = self.patcher.get_free_memory(self.device)
- batch_number = int(free_memory / memory_used)
- batch_number = max(1, batch_number)
- # Pre-allocate output for VAEs that support direct buffer writes
- preallocated = False
- if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
- pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
- preallocated = True
+ with model_management.cuda_device_context(self.device):
+ try:
+ memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+ free_memory = self.patcher.get_free_memory(self.device)
+ batch_number = int(free_memory / memory_used)
+ batch_number = max(1, batch_number)
- for x in range(0, samples_in.shape[0], batch_number):
- samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
- if preallocated:
- self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
- else:
- out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
- if pixel_samples is None:
- pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
- pixel_samples[x:x+batch_number].copy_(out)
- del out
- self.process_output(pixel_samples[x:x+batch_number])
- except Exception as e:
- model_management.raise_non_oom(e)
- logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
- #NOTE: We don't know what tensors were allocated to stack variables at the time of the
- #exception and the exception itself refs them all until we get out of this except block.
- #So we just set a flag for tiler fallback so that tensor gc can happen once the
- #exception is fully off the books.
- do_tile = True
+ # Pre-allocate output for VAEs that support direct buffer writes
+ preallocated = False
+ if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+ pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
+ preallocated = True
- if do_tile:
- comfy.model_management.soft_empty_cache()
- dims = samples_in.ndim - 2
- if dims == 1 or self.extra_1d_channel is not None:
- pixel_samples = self.decode_tiled_1d(samples_in)
- elif dims == 2:
- pixel_samples = self.decode_tiled_(samples_in)
- elif dims == 3:
- tile = 256 // self.spacial_compression_decode()
- overlap = tile // 4
- pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+ for x in range(0, samples_in.shape[0], batch_number):
+ samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
+ if preallocated:
+ self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
+ else:
+ out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
+ if pixel_samples is None:
+ pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+ pixel_samples[x:x+batch_number].copy_(out)
+ del out
+ self.process_output(pixel_samples[x:x+batch_number])
+ except Exception as e:
+ model_management.raise_non_oom(e)
+ logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+ #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+ #exception and the exception itself refs them all until we get out of this except block.
+ #So we just set a flag for tiler fallback so that tensor gc can happen once the
+ #exception is fully off the books.
+ do_tile = True
+
+ if do_tile:
+ comfy.model_management.soft_empty_cache()
+ dims = samples_in.ndim - 2
+ if dims == 1 or self.extra_1d_channel is not None:
+ pixel_samples = self.decode_tiled_1d(samples_in)
+ elif dims == 2:
+ pixel_samples = self.decode_tiled_(samples_in)
+ elif dims == 3:
+ tile = 256 // self.spacial_compression_decode()
+ overlap = tile // 4
+ pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@@ -1087,20 +1099,21 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
- if dims == 1 or self.extra_1d_channel is not None:
- args.pop("tile_y")
- output = self.decode_tiled_1d(samples, **args)
- elif dims == 2:
- output = self.decode_tiled_(samples, **args)
- elif dims == 3:
- if overlap_t is None:
- args["overlap"] = (1, overlap, overlap)
- else:
- args["overlap"] = (max(1, overlap_t), overlap, overlap)
- if tile_t is not None:
- args["tile_t"] = max(2, tile_t)
+ with model_management.cuda_device_context(self.device):
+ if dims == 1 or self.extra_1d_channel is not None:
+ args.pop("tile_y")
+ output = self.decode_tiled_1d(samples, **args)
+ elif dims == 2:
+ output = self.decode_tiled_(samples, **args)
+ elif dims == 3:
+ if overlap_t is None:
+ args["overlap"] = (1, overlap, overlap)
+ else:
+ args["overlap"] = (max(1, overlap_t), overlap, overlap)
+ if tile_t is not None:
+ args["tile_t"] = max(2, tile_t)
- output = self.decode_tiled_3d(samples, **args)
+ output = self.decode_tiled_3d(samples, **args)
return output.movedim(1, -1)
def encode(self, pixel_samples):
@@ -1113,44 +1126,46 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
pixel_samples = pixel_samples.unsqueeze(2)
- try:
- memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
- model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
- free_memory = self.patcher.get_free_memory(self.device)
- batch_number = int(free_memory / max(1, memory_used))
- batch_number = max(1, batch_number)
- samples = None
- for x in range(0, pixel_samples.shape[0], batch_number):
- pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
- if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
- out = self.first_stage_model.encode(pixels_in, device=self.device)
+
+ with model_management.cuda_device_context(self.device):
+ try:
+ memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+ free_memory = self.patcher.get_free_memory(self.device)
+ batch_number = int(free_memory / max(1, memory_used))
+ batch_number = max(1, batch_number)
+ samples = None
+ for x in range(0, pixel_samples.shape[0], batch_number):
+ pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
+ if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
+ out = self.first_stage_model.encode(pixels_in, device=self.device)
+ else:
+ pixels_in = pixels_in.to(self.device)
+ out = self.first_stage_model.encode(pixels_in)
+ out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
+ if samples is None:
+ samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
+ samples[x:x + batch_number] = out
+
+ except Exception as e:
+ model_management.raise_non_oom(e)
+ logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+ #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+ #exception and the exception itself refs them all until we get out of this except block.
+ #So we just set a flag for tiler fallback so that tensor gc can happen once the
+ #exception is fully off the books.
+ do_tile = True
+
+ if do_tile:
+ comfy.model_management.soft_empty_cache()
+ if self.latent_dim == 3:
+ tile = 256
+ overlap = tile // 4
+ samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
+ elif self.latent_dim == 1 or self.extra_1d_channel is not None:
+ samples = self.encode_tiled_1d(pixel_samples)
else:
- pixels_in = pixels_in.to(self.device)
- out = self.first_stage_model.encode(pixels_in)
- out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
- if samples is None:
- samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
- samples[x:x + batch_number] = out
-
- except Exception as e:
- model_management.raise_non_oom(e)
- logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
- #NOTE: We don't know what tensors were allocated to stack variables at the time of the
- #exception and the exception itself refs them all until we get out of this except block.
- #So we just set a flag for tiler fallback so that tensor gc can happen once the
- #exception is fully off the books.
- do_tile = True
-
- if do_tile:
- comfy.model_management.soft_empty_cache()
- if self.latent_dim == 3:
- tile = 256
- overlap = tile // 4
- samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
- elif self.latent_dim == 1 or self.extra_1d_channel is not None:
- samples = self.encode_tiled_1d(pixel_samples)
- else:
- samples = self.encode_tiled_(pixel_samples)
+ samples = self.encode_tiled_(pixel_samples)
return samples
@@ -1176,26 +1191,27 @@ class VAE:
if overlap is not None:
args["overlap"] = overlap
- if dims == 1:
- args.pop("tile_y")
- samples = self.encode_tiled_1d(pixel_samples, **args)
- elif dims == 2:
- samples = self.encode_tiled_(pixel_samples, **args)
- elif dims == 3:
- if tile_t is not None:
- tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
- else:
- tile_t_latent = 9999
- args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
+ with model_management.cuda_device_context(self.device):
+ if dims == 1:
+ args.pop("tile_y")
+ samples = self.encode_tiled_1d(pixel_samples, **args)
+ elif dims == 2:
+ samples = self.encode_tiled_(pixel_samples, **args)
+ elif dims == 3:
+ if tile_t is not None:
+ tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
+ else:
+ tile_t_latent = 9999
+ args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
- if overlap_t is None:
- args["overlap"] = (1, overlap, overlap)
- else:
- args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
- maximum = pixel_samples.shape[2]
- maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
+ if overlap_t is None:
+ args["overlap"] = (1, overlap, overlap)
+ else:
+ args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
+ maximum = pixel_samples.shape[2]
+ maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
- samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
+ samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
@@ -1269,6 +1285,8 @@ class CLIPType(Enum):
FLUX2 = 25
LONGCAT_IMAGE = 26
COGVIDEOX = 27
+ LENS = 28
+ PIXELDIT = 29
@@ -1321,6 +1339,7 @@ class TEModel(Enum):
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
T5_GEMMA = 32
+ GPT_OSS_20B = 33
def detect_te_model(sd):
@@ -1362,6 +1381,9 @@ def detect_te_model(sd):
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
+ # Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
+ if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
+ return TEModel.GPT_OSS_20B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
@@ -1508,8 +1530,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
- clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
- clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
+ if clip_type == CLIPType.PIXELDIT:
+ clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
+ else:
+ clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
@@ -1544,6 +1570,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
+ elif te_model == TEModel.GPT_OSS_20B:
+ clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
+ tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.QWEN3_4B:
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
@@ -1710,12 +1740,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
- if output_model and out[0] is not None:
- out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
- if output_clip and out[1] is not None:
- out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options))
+ if out[0] is not None:
+ out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
+ # Register reload factories for the CLIP and VAE produced by the same checkpoint so
+ # ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
+ # MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
+ # already-loaded module.
+ if out[1] is not None and getattr(out[1], "patcher", None) is not None:
+ out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
+ if out[2] is not None and getattr(out[2], "patcher", None) is not None:
+ out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
return out
+
+def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
+ """Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
+ factory for the CLIP returned by load_checkpoint_guess_config."""
+ _, clip, _, _ = load_checkpoint_guess_config(
+ ckpt_path,
+ output_vae=False,
+ output_clip=True,
+ output_clipvision=False,
+ embedding_directory=embedding_directory,
+ output_model=False,
+ model_options=model_options,
+ te_model_options=te_model_options,
+ disable_dynamic=disable_dynamic,
+ )
+ return clip.patcher
+
+
+def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
+ """Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
+ factory for the VAE returned by load_checkpoint_guess_config."""
+ _, _, vae, _ = load_checkpoint_guess_config(
+ ckpt_path,
+ output_vae=True,
+ output_clip=False,
+ output_clipvision=False,
+ embedding_directory=embedding_directory,
+ output_model=False,
+ model_options=model_options,
+ te_model_options=te_model_options,
+ disable_dynamic=disable_dynamic,
+ )
+ return vae.patcher
+
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
@@ -1742,7 +1812,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
- load_device = model_management.get_torch_device()
+ load_device = model_options.get("load_device", model_management.get_torch_device())
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
@@ -1782,13 +1852,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
- model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
+ offload_device = model_options.get("offload_device", model_management.unet_offload_device())
+ model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
- vae = VAE(sd=vae_sd, metadata=metadata)
+ vae_device = model_options.get("load_device", None)
+ vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
if output_clip:
if te_model_options.get("custom_operations", None) is None:
@@ -1872,7 +1944,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
- load_device = model_management.get_torch_device()
+ load_device = model_options.get("load_device", model_management.get_torch_device())
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
@@ -1897,7 +1969,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
else:
logging.warning("{} {}".format(diffusers_keys[k], k))
- offload_device = model_management.unet_offload_device()
+ offload_device = model_options.get("offload_device", model_management.unet_offload_device())
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
weight_dtype = None
@@ -1939,6 +2011,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model
+
+def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
+ """Reload a disk-backed VAE from ``vae_path`` and return its patcher.
+
+ Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
+ :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
+ fresh, untainted VAE patcher (no inherited per-device load state, no
+ in-place quantization fallout) for multigpu work-units and the
+ SelectVAEDevice node. The optional ``device`` matches the source loader's
+ VAE initialization path; the deepclone's ``load_device`` still controls
+ where the cloned patcher is targeted.
+ """
+ if metadata is None:
+ sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
+ else:
+ sd = comfy.utils.load_torch_file(vae_path)
+ vae = VAE(sd=sd, metadata=metadata, device=device)
+ vae.throw_exception_if_invalid()
+ return vae.patcher
+
def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 617db4f28..00941da53 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
+import comfy.text_encoders.pixeldit
from . import supported_models_base
from . import latent_formats
@@ -829,6 +830,50 @@ class Flux2(Flux):
return None
+
+class Lens(supported_models_base.BASE):
+ """Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
+
+ unet_config = {
+ "image_model": "lens",
+ }
+
+ sampling_settings = {
+ "shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300
+ }
+
+ unet_extra_config = {}
+ latent_format = latent_formats.Flux2
+
+ memory_usage_factor = 4.0
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ for hint in ("gpt_oss.transformer.", ""):
+ full_prefix = "{}{}".format(pref, hint)
+ if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
+ detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
+ return supported_models_base.ClipTarget(
+ comfy.text_encoders.gpt_oss.LensTokenizer,
+ comfy.text_encoders.gpt_oss.lens_te(**detect),
+ )
+ return supported_models_base.ClipTarget(
+ comfy.text_encoders.gpt_oss.LensTokenizer,
+ comfy.text_encoders.gpt_oss.lens_te(),
+ )
+
+
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage):
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
+class PixelDiTT2I(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "pixeldit_t2i",
+ }
+
+ unet_extra_config = {}
+
+ sampling_settings = {
+ "shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
+ }
+
+ latent_format = latent_formats.PixelDiTPixel
+ memory_usage_factor = 0.04
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.PixelDiTT2I(self, device=device)
+
+ def process_unet_state_dict(self, state_dict):
+ # pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
+ pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
+
+ out = {}
+ marker = ".adaLN_modulation.0."
+ for k, v in state_dict.items():
+ if k.startswith("_repa_projector") or k.startswith("net_ema."):
+ continue
+ if k.startswith("core."):
+ k = k[len("core."):]
+ elif k.startswith("net."):
+ k = k[len("net."):]
+ if "pixel_blocks." in k and marker in k:
+ # Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
+ p2 = v.shape[0] // (6 * pixel_dim)
+ trail = v.shape[1:] # () for bias, (in_dim,) for weight
+ vv = v.view(p2, 6, pixel_dim, *trail)
+ base, suffix = k.split(marker)
+ out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
+ out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
+ else:
+ out[k] = v
+ return out
+
+ def clip_target(self, state_dict={}):
+ return supported_models_base.ClipTarget(
+ comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
+ comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
+ )
+
+class PiD(PixelDiTT2I):
+ unet_config = {
+ "image_model": "pid",
+ }
+
+ sampling_settings = {
+ "shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
+ }
+
+ memory_usage_factor = 0.04
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.PiD(self, device=device)
+
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@@ -2069,6 +2180,8 @@ models = [
CosmosI2VPredict2,
ZImagePixelSpace,
ZImage,
+ PiD,
+ PixelDiTT2I,
Lumina2,
WAN22_T2V,
WAN21_CausalAR_T2V,
@@ -2096,6 +2209,7 @@ models = [
Omnigen2,
QwenImage,
Flux2,
+ Lens,
Kandinsky5Image,
Kandinsky5,
Anima,
diff --git a/comfy/text_encoders/gpt_oss.py b/comfy/text_encoders/gpt_oss.py
new file mode 100644
index 000000000..d596ef9a0
--- /dev/null
+++ b/comfy/text_encoders/gpt_oss.py
@@ -0,0 +1,600 @@
+"""GPT-OSS text encoder for Lens."""
+
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Any, List, Optional, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import comfy.ops
+from comfy import sd1_clip
+from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
+from comfy.text_encoders.llama import RMSNorm, apply_rope
+
+
+@dataclass
+class GptOss20BConfig:
+ vocab_size: int = 201088
+ hidden_size: int = 2880
+ intermediate_size: int = 2880
+ num_hidden_layers: int = 24
+ num_attention_heads: int = 64
+ num_key_value_heads: int = 8
+ head_dim: int = 64
+ num_local_experts: int = 32
+ num_experts_per_tok: int = 4
+ sliding_window: int = 128
+ original_max_position_embeddings: int = 4096
+ rope_theta: float = 150000.0
+ rope_factor: float = 32.0
+ rope_beta_fast: float = 32.0
+ rope_beta_slow: float = 1.0
+ rope_truncate: bool = False
+ rms_norm_eps: float = 1e-5
+ attention_bias: bool = True
+ layer_types: Optional[List[str]] = None
+ moe_alpha: float = 1.702
+ moe_limit: float = 7.0
+
+ def __post_init__(self):
+ if self.layer_types is None:
+ self.layer_types = [
+ "sliding_attention" if (i + 1) % 2 else "full_attention"
+ for i in range(self.num_hidden_layers)
+ ]
+
+
+def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
+ original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
+ """YARN inv_freq + attention scaling (matches transformers)."""
+ dim = head_dim
+
+ def find_correction_dim(num_rotations: float) -> float:
+ return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
+ 2 * math.log(base)
+ )
+
+ def find_correction_range() -> tuple[float, float]:
+ low = find_correction_dim(beta_fast)
+ high = find_correction_dim(beta_slow)
+ if truncate:
+ low = math.floor(low)
+ high = math.ceil(high)
+ return max(low, 0), min(high, dim - 1)
+
+ def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
+ if min_ == max_:
+ max_ += 0.001
+ linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
+ return torch.clamp(linear, 0, 1)
+
+ def get_mscale(scale: float) -> float:
+ if scale <= 1:
+ return 1.0
+ return 0.1 * math.log(scale) + 1.0
+
+ attention_scaling = get_mscale(factor)
+
+ pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
+ inv_freq_extrapolation = 1.0 / pos_freqs
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+ low, high = find_correction_range()
+ extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
+ inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
+ return inv_freq, attention_scaling
+
+
+def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ pos_e = position_ids[:, None, :].float()
+ freqs = (inv_freq_e @ pos_e).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
+ sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
+ sin_split = sin.shape[-1] // 2
+ return cos, sin[..., :sin_split], -sin[..., sin_split:]
+
+
+def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
+ attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
+ """Attention with per-head sinks.
+
+ Sinks add a learned term to each row's softmax denominator but contribute
+ nothing to the output. We fake this by appending one zero k/v position and
+ putting the sink logit in the mask at that column.
+ """
+
+ if num_kv_groups > 1 and not TORCH_HAS_GQA:
+ k = k.repeat_interleave(num_kv_groups, dim=1)
+ v = v.repeat_interleave(num_kv_groups, dim=1)
+
+ B, _, S_q, D = q.shape
+ H_kv = k.shape[1]
+ S_kv = k.shape[-2]
+
+ k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
+ v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
+
+ sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
+ if attention_mask is not None:
+ mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
+ else:
+ mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
+ mask = torch.cat([mask_left, sinks_col], dim=-1)
+
+ op = optimized_attention_for_device(q.device, mask=True, small_input=True)
+ return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
+
+
+class GptOssAttention(nn.Module):
+ def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.layer_type = config.layer_types[layer_idx]
+ self.num_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.num_kv_groups = self.num_heads // self.num_kv_heads
+ self.head_dim = config.head_dim
+ self.hidden_size = config.hidden_size
+ self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
+
+ bias = config.attention_bias
+ self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
+ self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
+ self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
+ self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
+ self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
+
+ def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
+ B, S, _ = hidden_states.shape
+
+ q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
+ k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
+ v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
+
+ q, k = apply_rope(q, k, freqs_cis)
+
+ out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
+ return self.o_proj(out)
+
+
+# Mixture of Experts
+
+class GptOssTopKRouter(nn.Module):
+ def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
+ super().__init__()
+ self.top_k = config.num_experts_per_tok
+ self.num_experts = config.num_local_experts
+ self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
+ self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
+
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
+ bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
+ logits = F.linear(hidden_states, weight, bias)
+ top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
+ # Softmax over top-k slice only
+ scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
+ return scores, top_idx
+
+
+class GptOssExperts(nn.Module):
+ def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
+ super().__init__()
+ self.num_experts = config.num_local_experts
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.alpha = config.moe_alpha
+ self.limit = config.moe_limit
+
+ E = self.num_experts
+ H = self.hidden_size
+ I = self.intermediate_size
+
+ self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
+ self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
+
+ def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
+ gate = gate_up[..., ::2]
+ up = gate_up[..., 1::2]
+ gate = gate.clamp(max=self.limit)
+ up = up.clamp(min=-self.limit, max=self.limit)
+ glu = gate * torch.sigmoid(gate * self.alpha)
+ return torch.addcmul(glu, up, glu)
+
+ def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
+ N = hidden_states.shape[0]
+ top_k = router_indices.shape[-1]
+ H = hidden_states.shape[-1]
+
+ per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
+
+ expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+ with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
+ self.down_proj.bank_resident(hidden_states) as down_bank:
+ for ei in expert_hit:
+ expert_idx = int(ei.item())
+ top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+ current = hidden_states[token_idx]
+
+ gate_up = gate_up_bank.expert_linear(current, expert_idx)
+ gated = self._apply_gate(gate_up)
+ expert_out = down_bank.expert_linear(gated, expert_idx)
+
+ weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
+
+ flat_idx = token_idx * top_k + top_k_pos
+ per_pair[flat_idx] = weighted.to(per_pair.dtype)
+
+ return per_pair.view(N, top_k, H).sum(dim=1)
+
+
+class GptOssMLP(nn.Module):
+ def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
+ super().__init__()
+ self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
+ self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ B, S, H = hidden_states.shape
+ flat = hidden_states.reshape(-1, H)
+ scores, idx = self.router(flat)
+ out = self.experts(flat, idx, scores)
+ return out.reshape(B, S, H)
+
+
+# Decoder layer + model
+
+class GptOssDecoderLayer(nn.Module):
+ def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
+ super().__init__()
+ self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
+ self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.layer_type = config.layer_types[layer_idx]
+
+ def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
+ residual = x
+ x = self.input_layernorm(x)
+ x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
+ x = residual + x
+
+ residual = x
+ x = self.post_attention_layernorm(x)
+ x = self.mlp(x)
+ x = residual + x
+ return x
+
+
+def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
+ neg = torch.finfo(dtype).min
+ mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
+ mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
+ if key_padding_mask is not None:
+ kp = key_padding_mask.to(dtype=dtype)
+ kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
+ mask = mask + kp
+ return mask
+
+
+def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
+ neg = torch.finfo(dtype).min
+ i = torch.arange(S, device=device).view(-1, 1)
+ j = torch.arange(S, device=device).view(1, -1)
+ keep = (j <= i) & (j > i - window)
+ mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
+ mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
+ if key_padding_mask is not None:
+ kp = key_padding_mask.to(dtype=dtype)
+ kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
+ mask = mask + kp
+ return mask
+
+
+class GptOssModel(nn.Module):
+ """GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
+
+ def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
+ super().__init__()
+ self.config = config
+ self.dtype = dtype
+ self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
+ self.layers = nn.ModuleList(
+ [
+ GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+
+ # Always build on CPU so the buffer survives meta-device construction.
+ inv_freq, attn_scaling = _yarn_inv_freq(
+ head_dim=config.head_dim,
+ base=config.rope_theta,
+ factor=config.rope_factor,
+ beta_fast=config.rope_beta_fast,
+ beta_slow=config.rope_beta_slow,
+ original_max_position_embeddings=config.original_max_position_embeddings,
+ truncate=config.rope_truncate,
+ device=torch.device("cpu"),
+ )
+ self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
+ self.rope_attention_scaling = float(attn_scaling)
+
+ @property
+ def num_layers(self) -> int:
+ return self.config.num_hidden_layers
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
+ ) -> dict[str, torch.Tensor]:
+ full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
+ masks = {"full_attention": full}
+ if any(t == "sliding_attention" for t in self.config.layer_types):
+ masks["sliding_attention"] = _make_sliding_causal_mask(
+ B, S, self.config.sliding_window, attention_mask, dtype, device
+ )
+ return masks
+
+ def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
+ capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
+ B, S = input_ids.shape
+ device = input_ids.device
+ dtype = self.dtype
+
+ hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
+
+ position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
+ freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
+
+ attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
+
+ capture_layers = list(capture_layers) if capture_layers else None
+ if capture_layers:
+ max_layer = max(capture_layers)
+ wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
+ captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
+ else:
+ max_layer = self.config.num_hidden_layers - 1
+ wanted = None
+ captured = None
+
+ for i, layer in enumerate(self.layers):
+ hidden_states = layer(hidden_states, attn_masks, freqs_cis)
+ if wanted is not None and i in wanted:
+ captured[wanted[i]] = hidden_states
+ if i >= max_layer:
+ break
+
+ if captured is not None:
+ return {"hidden_states": captured}
+ return {"last_hidden_state": self.norm(hidden_states)}
+
+
+# Lens chat-template constants (verbatim from the reference pipeline).
+_LENS_CHAT_SYSTEM = (
+ "Describe the image by detailing the color, shape, size, texture, "
+ "quantity, text, spatial relationships of the objects and background."
+)
+_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
+LENS_TXT_OFFSET = 97
+LENS_SELECTED_LAYERS = (5, 11, 17, 23)
+LENS_MAX_TOKENS = 512
+
+
+# The reference GPT-OSS Harmony template injects today's date here
+_LENS_CHAT_DATE = "2026-05-23"
+
+
+def _lens_render_chat(prompt: str) -> str:
+ """Render the Lens prompt in GPT-OSS Harmony format."""
+ return (
+ f"<|start|>system<|message|>"
+ f"You are ChatGPT, a large language model trained by OpenAI.\n"
+ f"Knowledge cutoff: 2024-06\n"
+ f"Current date: {_LENS_CHAT_DATE}\n\n"
+ f"Reasoning: medium\n\n"
+ f"# Valid channels: analysis, commentary, final. "
+ f"Channel must be included for every message.<|end|>"
+ f"<|start|>developer<|message|># Instructions\n\n"
+ f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
+ f"<|start|>user<|message|>{prompt}<|end|>"
+ f"<|start|>assistant<|channel|>analysis<|message|>"
+ f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
+ f"<|start|>assistant<|channel|>final<|message|>"
+ )
+
+
+# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
+_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
+
+
+class _GptOssRawTokenizer:
+ """Raw ``tokenizers.Tokenizer`` wrapper.
+
+ The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
+ (``tokenizer_json`` key) rather than as a committed file. Extracted
+ it in ``sd.py`` and passes it here via ``tokenizer_data``.
+ """
+
+ def __init__(self, tokenizer_json_bytes=None, **kwargs):
+ from tokenizers import Tokenizer
+ if isinstance(tokenizer_json_bytes, torch.Tensor):
+ tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
+ if tokenizer_json_bytes is None:
+ raise ValueError(
+ "Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
+ "encoder state dict. Re-bundle the encoder via bundle_te.py so it "
+ "embeds the tokenizer."
+ )
+ self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
+
+ @classmethod
+ def from_pretrained(cls, tokenizer_data, **kwargs):
+ return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
+
+ def __call__(self, text):
+ return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
+
+ def get_vocab(self):
+ return self.tokenizer.get_vocab()
+
+ def convert_tokens_to_ids(self, tokens):
+ return [self.tokenizer.token_to_id(t) for t in tokens]
+
+ def decode(self, ids, **kwargs):
+ return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
+
+
+class LensGptOssTokenizer(sd1_clip.SDTokenizer):
+ tokenizer_json_data = None
+
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer_json = tokenizer_data.get("tokenizer_json", None)
+ self.tokenizer_json_data = tokenizer_json
+ super().__init__(
+ tokenizer_json,
+ embedding_directory=embedding_directory,
+ pad_with_end=False,
+ embedding_size=2880,
+ embedding_key="gpt_oss",
+ tokenizer_class=_GptOssRawTokenizer,
+ has_start_token=False,
+ has_end_token=False,
+ pad_to_max_length=False,
+ max_length=99999999,
+ min_length=1,
+ pad_left=False,
+ disable_weights=True,
+ tokenizer_data=tokenizer_data,
+ )
+ self.pad_token_id = _LENS_PAD_TOKEN_ID
+
+ def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
+ # Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
+ if not text or not text.strip():
+ return [[]]
+ rendered = _lens_render_chat(text)
+ ids = self.tokenizer(rendered)["input_ids"]
+ if len(ids) > LENS_MAX_TOKENS:
+ ids = ids[:LENS_MAX_TOKENS]
+ return [[(int(t), 1.0) for t in ids]]
+
+ def state_dict(self):
+ if self.tokenizer_json_data is not None:
+ return {"tokenizer_json": self.tokenizer_json_data}
+ return {}
+
+
+class LensTokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(
+ embedding_directory=embedding_directory,
+ tokenizer_data=tokenizer_data,
+ name="gpt_oss",
+ tokenizer=LensGptOssTokenizer,
+ )
+
+
+class LensGptOssClipModel(nn.Module):
+ """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
+
+ def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
+ super().__init__()
+ model_options = dict(model_options or {})
+
+ operations = model_options.get("custom_operations")
+ if operations is None:
+ quant_config = model_options.get("quantization_metadata") or {}
+ operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
+ self.operations = operations
+
+ cfg_overrides = model_options.get("gpt_oss_config", {})
+ self.config = GptOss20BConfig(**cfg_overrides)
+ self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
+ self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
+
+ self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
+ self.num_layers = self.config.num_hidden_layers
+ self.dtype = dtype
+ self.execution_device = None
+ self._pad_token_id = _LENS_PAD_TOKEN_ID
+
+ def set_clip_options(self, options):
+ self.execution_device = options.get("execution_device", self.execution_device)
+
+ def reset_clip_options(self):
+ self.execution_device = None
+
+ def _gather_tokens(self, token_weight_pairs):
+ ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
+ pad_id = self._pad_token_id
+ max_len = max(len(x) for x in ids_list)
+ device = self.execution_device
+ ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
+ mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
+ for i, x in enumerate(ids_list):
+ ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
+ mask[i, : len(x)] = 1
+ return ids, mask
+
+ def encode_token_weights(self, token_weight_pairs):
+ # Empty negative: emit zero-length features + zero mask
+ if all(len(batch) == 0 for batch in token_weight_pairs):
+ device = self.execution_device
+ B = len(token_weight_pairs)
+ L = len(self.selected_layers)
+ H = self.config.hidden_size
+ flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
+ mask = torch.zeros(B, 0, dtype=torch.long, device=device)
+ return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
+
+ input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
+ out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
+ layers = out["hidden_states"] # list of L × [B, S, H]
+ stacked = torch.stack(layers, dim=2) # [B, S, L, H]
+
+ offset = self.txt_offset
+ if stacked.shape[1] > offset:
+ stacked = stacked[:, offset:].contiguous()
+ mask_trim = attn_mask[:, offset:]
+ else:
+ stacked = stacked[:, :0]
+ mask_trim = attn_mask[:, :0]
+
+ B, S, L, H = stacked.shape
+ flat = stacked.reshape(B, S, L * H)
+ extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
+ return flat, None, extra
+
+ def load_sd(self, sd):
+ return self.transformer.load_state_dict(sd, strict=False, assign=True)
+
+
+class LensTEModel(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None, model_options=None):
+ super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
+
+
+def lens_te(dtype_llama=None, llama_quantization_metadata=None):
+ class LensTEModel_(LensTEModel):
+ def __init__(self, device="cpu", dtype=None, model_options=None):
+ mo = dict(model_options or {})
+ if llama_quantization_metadata is not None:
+ mo["quantization_metadata"] = llama_quantization_metadata
+ if dtype is None and dtype_llama is not None:
+ dtype = dtype_llama
+ super().__init__(device=device, dtype=dtype, model_options=mo)
+
+ return LensTEModel_
diff --git a/comfy/text_encoders/pixeldit.py b/comfy/text_encoders/pixeldit.py
new file mode 100644
index 000000000..3539711e4
--- /dev/null
+++ b/comfy/text_encoders/pixeldit.py
@@ -0,0 +1,104 @@
+import torch
+
+from comfy import sd1_clip
+from .lumina2 import Gemma2BTokenizer, LuminaModel
+import comfy.text_encoders.llama
+
+
+class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+
+ super().__init__(
+ device=device, layer=layer, layer_idx=layer_idx,
+ textmodel_json_config={}, dtype=dtype,
+ special_tokens={"start": 2, "pad": 0},
+ layer_norm_hidden_state=False,
+ model_class=comfy.text_encoders.llama.Gemma2_2B,
+ enable_attention_masks=attention_mask,
+ return_attention_masks=attention_mask,
+ model_options=model_options,
+ )
+
+
+_PIXELDIT_CHI_PROMPT = (
+ 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
+ "suitable for image generation. Evaluate the level of detail in the user prompt:\n"
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
+ "and spatial relationships to create vivid and concrete scenes.\n"
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without "
+ "overcomplicating.\n"
+ "Here are examples of how to transform or refine prompts:\n"
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
+ "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
+ "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
+ "passing by towering glass skyscrapers.\n"
+ "Please generate only the enhanced description for the prompt below and avoid including any "
+ "additional commentary or evaluations:\n"
+ "User Prompt: "
+)
+
+_PIXELDIT_MAX_LENGTH = 300
+_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
+
+
+class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data=None):
+ if tokenizer_data is None:
+ tokenizer_data = {}
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
+ name="gemma2_2b", tokenizer=Gemma2BTokenizer)
+
+ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
+ if not text.strip():
+ return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
+
+ chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
+ combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
+ max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
+ out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
+ disable_weights=True, min_length=max_length_all)
+ out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
+ return out
+
+ def untokenize(self, token_weight_pair):
+ return self.gemma2_2b.untokenize(token_weight_pair)
+
+ def state_dict(self):
+ return self.gemma2_2b.state_dict()
+
+
+class PixelDiTGemma2TE(LuminaModel):
+ # PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super().__init__(device=device, dtype=dtype, name="gemma2_2b",
+ clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
+
+ def encode_token_weights(self, token_weight_pairs):
+ result = super().encode_token_weights(token_weight_pairs)
+ cond, pooled = result[0], result[1]
+ extra = result[2] if len(result) > 2 else None
+ if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
+ cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
+ if extra is not None and "attention_mask" in extra:
+ am = extra["attention_mask"]
+ extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
+ if extra is not None:
+ return cond, pooled, extra
+ return cond, pooled
+
+
+def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
+ class PixelDiTTE_(PixelDiTGemma2TE):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return PixelDiTTE_
diff --git a/comfy/utils.py b/comfy/utils.py
index 31052714a..49ae12b06 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -86,6 +86,7 @@ def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0)
+ file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@@ -111,7 +112,7 @@ def load_safetensors(ckpt):
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
- comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
+ comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
sd[name] = tensor
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 04973fea0..e0a585b10 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 942278d88..99e67d363 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py
index c92477f08..6c9d6a526 100644
--- a/comfy_api/latest/_util/video_types.py
+++ b/comfy_api/latest/_util/video_types.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py
index 46a583b5e..9c4cfb9b6 100644
--- a/comfy_api_nodes/apis/__init__.py
+++ b/comfy_api_nodes/apis/__init__.py
@@ -3,7 +3,6 @@
# timestamp: 2025-07-30T08:54:00+00:00
# pylint: disable
-from __future__ import annotations
from datetime import date, datetime
from enum import Enum
diff --git a/comfy_api_nodes/apis/bfl.py b/comfy_api_nodes/apis/bfl.py
index d8d3557b3..f0665fa09 100644
--- a/comfy_api_nodes/apis/bfl.py
+++ b/comfy_api_nodes/apis/bfl.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
from enum import Enum
from typing import Any, Dict, Optional
diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py
index 03f4c445b..47f24586c 100644
--- a/comfy_api_nodes/apis/bytedance.py
+++ b/comfy_api_nodes/apis/bytedance.py
@@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel):
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
- url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
+ url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
+ asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
diff --git a/comfy_api_nodes/apis/krea.py b/comfy_api_nodes/apis/krea.py
new file mode 100644
index 000000000..6e294a3b7
--- /dev/null
+++ b/comfy_api_nodes/apis/krea.py
@@ -0,0 +1,46 @@
+"""Pydantic models for the Krea image-generation API."""
+
+from pydantic import BaseModel, Field
+
+
+class KreaMoodboard(BaseModel):
+ id: str = Field(...)
+ strength: float = Field(default=0.35, ge=-0.5, le=1.5)
+
+
+class KreaImageStyleReference(BaseModel):
+ strength: float = Field(..., ge=-2.0, le=2.0)
+ url: str | None = Field(default=None)
+
+
+class KreaGenerateImageRequest(BaseModel):
+ prompt: str = Field(...)
+ aspect_ratio: str = Field(...)
+ resolution: str = Field(...)
+ seed: int | None = Field(default=None)
+ creativity: str = Field(default="medium")
+ moodboards: list[KreaMoodboard] | None = Field(default=None)
+ image_style_references: list[KreaImageStyleReference] | None = Field(default=None)
+
+
+class KreaJobResult(BaseModel):
+ urls: list[str] | None = Field(default=None)
+ style_id: str | None = Field(default=None)
+
+
+class KreaJob(BaseModel):
+ job_id: str = Field(...)
+ status: str = Field(...)
+ created_at: str = Field(...)
+ completed_at: str | None = Field(default=None)
+ result: KreaJobResult | None = Field(default=None)
+
+
+class KreaAssetResponse(BaseModel):
+ id: str = Field(...)
+ image_url: str = Field(...)
+ uploaded_at: str = Field(...)
+ width: float | None = Field(default=None)
+ height: float | None = Field(default=None)
+ size_bytes: float | None = Field(default=None)
+ mime_type: str | None = Field(default=None)
diff --git a/comfy_api_nodes/apis/stability.py b/comfy_api_nodes/apis/stability.py
index 718360187..5b9b5ac7d 100644
--- a/comfy_api_nodes/apis/stability.py
+++ b/comfy_api_nodes/apis/stability.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
from enum import Enum
from typing import Optional
diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py
index e08fc0b01..1b4abafe6 100644
--- a/comfy_api_nodes/nodes_bytedance.py
+++ b/comfy_api_nodes/nodes_bytedance.py
@@ -2,11 +2,12 @@ import hashlib
import logging
import math
import re
+from io import BytesIO
import torch
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
@@ -308,6 +309,26 @@ async def _seedance_virtual_library_upload_image_asset(
return f"asset://{create_resp.asset_id}"
+async def _seedance_virtual_library_upload_video_asset(
+ cls: type[IO.ComfyNode],
+ video: Input.Video,
+ *,
+ wait_label: str = "Uploading video",
+) -> str:
+ buf = BytesIO()
+ video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
+ video_hash = hashlib.sha256(buf.getbuffer()).hexdigest()
+ public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)
+ create_resp = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"),
+ response_model=SeedanceCreateAssetResponse,
+ data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"),
+ )
+ await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library")
+ return f"asset://{create_resp.asset_id}"
+
+
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
@@ -2106,7 +2127,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
content.append(
TaskVideoContent(
video_url=TaskVideoContentUrl(
- url=await upload_video_to_comfyapi(
+ url=await _seedance_virtual_library_upload_video_asset(
cls,
reference_videos[key],
wait_label=f"Uploading video {i}",
diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py
new file mode 100644
index 000000000..003a8a654
--- /dev/null
+++ b/comfy_api_nodes/nodes_krea.py
@@ -0,0 +1,290 @@
+"""Krea image-generation nodes."""
+
+import re
+
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.krea import (
+ KreaAssetResponse,
+ KreaGenerateImageRequest,
+ KreaImageStyleReference,
+ KreaJob,
+ KreaMoodboard,
+)
+from comfy_api_nodes.util import (
+ ApiEndpoint,
+ download_url_to_image_tensor,
+ poll_op,
+ sync_op,
+ tensor_to_bytesio,
+ validate_string,
+)
+
+
+class KreaIO:
+ STYLE_REF = "KREA_STYLE_REF"
+
+
+async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str:
+ """Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL."""
+ img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png")
+ response = await sync_op(
+ cls,
+ endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"),
+ response_model=KreaAssetResponse,
+ files=[("file", (img_io.name, img_io, "image/png"))],
+ content_type="multipart/form-data",
+ max_retries=1,
+ wait_label="Uploading reference",
+ )
+ return response.image_url
+
+
+_MODEL_MEDIUM = "Krea 2 Medium"
+_MODEL_LARGE = "Krea 2 Large"
+_MODEL_ENDPOINTS: dict[str, str] = {
+ _MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
+ _MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
+}
+
+_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"]
+_RESOLUTIONS = ["1K"]
+_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"]
+_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"]
+
+_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
+
+
+def _krea_model_inputs() -> list:
+ """Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
+ return [
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=_ASPECT_RATIOS,
+ tooltip="Output aspect ratio.",
+ ),
+ IO.Combo.Input(
+ "resolution",
+ options=_RESOLUTIONS,
+ tooltip="Resolution scale.",
+ ),
+ IO.Combo.Input(
+ "creativity",
+ options=_CREATIVITY_LEVELS,
+ default="medium",
+ tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.",
+ ),
+ IO.String.Input(
+ "moodboard_id",
+ default="",
+ tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). "
+ "Leave empty to disable. Only one moodboard is supported per request.",
+ optional=True,
+ ),
+ IO.Float.Input(
+ "moodboard_strength",
+ default=0.35,
+ min=-0.5,
+ max=1.5,
+ step=0.05,
+ tooltip="Moodboard influence; ignored when moodboard_id is empty.",
+ optional=True,
+ ),
+ IO.Custom(KreaIO.STYLE_REF).Input(
+ "style_reference",
+ optional=True,
+ tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.",
+ ),
+ ]
+
+
+class Krea2ImageNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="Krea2ImageNode",
+ display_name="Krea 2 Image",
+ category="api node/image/Krea",
+ description=(
+ "Generate images via Krea 2 — pick Medium (expressive illustrations) or "
+ "Large (expressive photorealism). Supports an optional moodboard and up "
+ "to 10 chained image style references."
+ ),
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Text prompt for the image.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
+ IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
+ ],
+ tooltip="Krea 2 Medium is best for expressive illustrations; "
+ "Krea 2 Large is best for expressive photorealism.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ control_after_generate=True,
+ tooltip="Random seed for reproducibility.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model", "model.moodboard_id"],
+ inputs=["model.style_reference"],
+ ),
+ expr="""
+ (
+ $isLarge := widgets.model = "krea 2 large";
+ $hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
+ $hasStyle := $lookup(inputs, "model.style_reference").connected;
+ $usd := $hasMoodboard
+ ? ($isLarge ? 0.07 : 0.04)
+ : ($hasStyle
+ ? ($isLarge ? 0.065 : 0.035)
+ : ($isLarge ? 0.06 : 0.03));
+ {"type":"usd","usd": $usd}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ model: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, strip_whitespace=False, min_length=1)
+
+ model_choice = model["model"]
+ endpoint_path = _MODEL_ENDPOINTS.get(model_choice)
+ if endpoint_path is None:
+ raise ValueError(f"Unknown Krea 2 model: {model_choice!r}")
+
+ moodboards: list[KreaMoodboard] | None = None
+ mb_id = (model.get("moodboard_id") or "").strip()
+ if mb_id:
+ if not _UUID_RE.match(mb_id):
+ raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.")
+ mb_strength = model.get("moodboard_strength")
+ moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))]
+
+ style_reference = model.get("style_reference")
+ image_style_references: list[KreaImageStyleReference] | None = None
+ if style_reference:
+ if len(style_reference) > 10:
+ raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.")
+ image_style_references = [
+ KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference
+ ]
+ initial = await sync_op(
+ cls,
+ ApiEndpoint(path=endpoint_path, method="POST"),
+ response_model=KreaJob,
+ data=KreaGenerateImageRequest(
+ prompt=prompt,
+ aspect_ratio=model["aspect_ratio"],
+ resolution=model["resolution"],
+ seed=seed,
+ creativity=model["creativity"],
+ moodboards=moodboards,
+ image_style_references=image_style_references,
+ ),
+ )
+ job = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"),
+ response_model=KreaJob,
+ status_extractor=lambda r: r.status,
+ queued_statuses=_KREA_QUEUED_STATUSES,
+ )
+ if not job.result or not job.result.urls:
+ raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.")
+ image = await download_url_to_image_tensor(job.result.urls[0])
+ return IO.NodeOutput(image)
+
+
+class Krea2StyleReferenceNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="Krea2StyleReferenceNode",
+ display_name="Krea 2 Style Reference",
+ category="api node/image/Krea",
+ description=(
+ "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
+ "Style Reference nodes (max 10) and feed the final `style_reference` output "
+ "into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL."
+ ),
+ inputs=[
+ IO.Image.Input(
+ "image",
+ tooltip="Reference image whose style influences the generation.",
+ ),
+ IO.Float.Input(
+ "strength",
+ default=1.0,
+ min=-2.0,
+ max=2.0,
+ step=0.05,
+ tooltip="Reference strength; negative values invert the style influence.",
+ ),
+ IO.Custom(KreaIO.STYLE_REF).Input(
+ "style_reference",
+ optional=True,
+ tooltip="Optional incoming chain of style references; this node appends one more.",
+ ),
+ ],
+ outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ image: Input.Image,
+ strength: float,
+ style_reference: list[dict] | None = None,
+ ) -> IO.NodeOutput:
+ chain: list[dict] = list(style_reference) if style_reference else []
+ if len(chain) >= 10:
+ raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.")
+ url = await _upload_image_to_krea_assets(cls, image)
+ chain.append({"url": url, "strength": float(strength)})
+ return IO.NodeOutput(chain)
+
+
+class KreaExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ Krea2ImageNode,
+ Krea2StyleReferenceNode,
+ ]
+
+
+async def comfy_entrypoint() -> KreaExtension:
+ return KreaExtension()
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index c47f3c79b..479ee8a53 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
from typing import Type, Literal
import nodes
diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py
index f951a3350..731b8dc66 100644
--- a/comfy_execution/progress.py
+++ b/comfy_execution/progress.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
from typing import TypedDict, Dict, Optional, Tuple
from typing_extensions import override
from PIL import Image
diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py
index e73624bd1..ae9a2376c 100644
--- a/comfy_execution/validation.py
+++ b/comfy_execution/validation.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
from comfy_api.latest import IO
diff --git a/comfy_extras/mediapipe/face_geometry.py b/comfy_extras/mediapipe/face_geometry.py
index 04b2b0557..4f3813430 100644
--- a/comfy_extras/mediapipe/face_geometry.py
+++ b/comfy_extras/mediapipe/face_geometry.py
@@ -2,7 +2,6 @@
+ weighted Procrustes solver. Computes the 4x4 facial transformation matrix.
"""
-from __future__ import annotations
import math
import numpy as np
diff --git a/comfy_extras/mediapipe/face_landmarker.py b/comfy_extras/mediapipe/face_landmarker.py
index 6a9a25f82..95c67b321 100644
--- a/comfy_extras/mediapipe/face_landmarker.py
+++ b/comfy_extras/mediapipe/face_landmarker.py
@@ -1,7 +1,6 @@
"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task:
BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes."""
-from __future__ import annotations
import math
from functools import lru_cache
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index d5084497e..f09a8a874 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import av
import torchaudio
import torch
diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py
index 4ebb4b51e..b585c560f 100644
--- a/comfy_extras/nodes_cfg.py
+++ b/comfy_extras/nodes_cfg.py
@@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
+ io.Boolean.Input(
+ "pre_cfg",
+ default=False,
+ optional=True,
+ tooltip=(
+ "If true, rescale the combined noise BEFORE the sampler's CFG combine, "
+ "without clamping (can amplify). Matches the norm-scaled CFG used by "
+ "models like Lens. Default false keeps the original post-CFG x0-space "
+ "attenuate-only behavior."
+ ),
+ ),
],
outputs=[io.Model.Output(display_name="patched_model")],
is_experimental=True,
)
@classmethod
- def execute(cls, model, strength) -> io.NodeOutput:
+ def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
m = model.clone()
- def cfg_norm(args):
- cond_p = args['cond_denoised']
- pred_text_ = args["denoised"]
+ if pre_cfg:
+ def cfg_norm_pre(args):
+ cond = args["cond"]
+ uncond = args["uncond"]
+ cond_scale = args["cond_scale"]
+ comb = uncond + cond_scale * (cond - uncond)
+ cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
+ comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
+ rescale = torch.where(
+ comb_norm > 0,
+ cond_norm / comb_norm.clamp_min(1e-12),
+ torch.ones_like(comb_norm),
+ )
+ rescaled = comb * rescale
+ # strength blends back toward standard linear CFG (1.0 = full rescale).
+ if strength != 1.0:
+ rescaled = strength * rescaled + (1.0 - strength) * comb
+ return rescaled
+ m.set_model_sampler_cfg_function(cfg_norm_pre)
+ else:
+ def cfg_norm(args):
+ cond_p = args['cond_denoised']
+ pred_text_ = args["denoised"]
- norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
- norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
- scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
- return pred_text_ * scale * strength
+ norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
+ norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
+ scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
+ return pred_text_ * scale * strength
- m.set_model_sampler_post_cfg_function(cfg_norm)
+ m.set_model_sampler_post_cfg_function(cfg_norm)
return io.NodeOutput(m)
diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py
index f7ca833dc..24729c3a7 100644
--- a/comfy_extras/nodes_context_windows.py
+++ b/comfy_extras/nodes_context_windows.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
from comfy_api.latest import ComfyExtension, io
import comfy.context_windows
import nodes
diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py
index 9803e8034..099453131 100644
--- a/comfy_extras/nodes_curve.py
+++ b/comfy_extras/nodes_curve.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import numpy as np
from comfy_api.latest import ComfyExtension, io
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index 33933229d..fe6008aa3 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import nodes
import folder_paths
diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py
index 342cadb69..92507f1fc 100644
--- a/comfy_extras/nodes_logic.py
+++ b/comfy_extras/nodes_logic.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
from typing import TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index 51cf7951f..48d75c9e5 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -226,10 +226,20 @@ def get_noise_mask(latent):
noise_mask = noise_mask.clone()
return noise_mask
-def get_keyframe_idxs(cond):
+def get_keyframe_idxs(cond, latent_shape=None):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
+ # Get number of keyframes from latent_shape or guide_attention_entries if available
+ if latent_shape is not None and len(latent_shape) == 5:
+ tokens_per_frame = latent_shape[-2] * latent_shape[-1]
+ num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
+ return keyframe_idxs, num_keyframes
+ entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
+ if entries:
+ num_keyframes = sum(e["latent_shape"][0] for e in entries)
+ return keyframe_idxs, num_keyframes
+ # fallback, may under-count if keyframes share t-start
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
return keyframe_idxs, num_keyframes
@@ -322,9 +332,9 @@ class LTXVAddGuide(io.ComfyNode):
return factor
@classmethod
- def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
+ def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
time_scale_factor, _, _ = scale_factors
- _, num_keyframes = get_keyframe_idxs(cond)
+ _, num_keyframes = get_keyframe_idxs(cond, latent_shape)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1 and frame_idx != 0:
@@ -436,7 +446,7 @@ class LTXVAddGuide(io.ComfyNode):
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
resolved_frame_idx = frame_idx
if frame_idx < 0:
- _, num_keyframes = get_keyframe_idxs(positive)
+ _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
@@ -454,7 +464,7 @@ class LTXVAddGuide(io.ComfyNode):
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
- frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
+ frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
positive, negative, latent_image, noise_mask = cls.append_keyframe(
@@ -506,7 +516,7 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
- _, num_keyframes = get_keyframe_idxs(positive)
+ _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
if num_keyframes == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py
index 06aefa475..0040d1a92 100644
--- a/comfy_extras/nodes_math.py
+++ b/comfy_extras/nodes_math.py
@@ -4,7 +4,6 @@ Provides a ComfyMathExpression node that evaluates math expressions
against dynamically-grown numeric inputs.
"""
-from __future__ import annotations
import math
import string
diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py
index 6b7916aee..32dc22de3 100644
--- a/comfy_extras/nodes_mediapipe.py
+++ b/comfy_extras/nodes_mediapipe.py
@@ -10,7 +10,6 @@ Custom IO types:
MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes.
"""
-from __future__ import annotations
import numpy as np
import torch
diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py
index 3508781a0..79aec5d7f 100644
--- a/comfy_extras/nodes_moge.py
+++ b/comfy_extras/nodes_moge.py
@@ -1,6 +1,5 @@
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
-from __future__ import annotations
import torch
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
new file mode 100644
index 000000000..d2f6fe67a
--- /dev/null
+++ b/comfy_extras/nodes_multigpu.py
@@ -0,0 +1,408 @@
+from __future__ import annotations
+
+import copy
+import logging
+from inspect import cleandoc
+from typing import TYPE_CHECKING
+from typing_extensions import override
+
+from comfy_api.latest import ComfyExtension, io
+
+if TYPE_CHECKING:
+ from comfy.model_patcher import ModelPatcher
+ from comfy.sd import CLIP, VAE
+import torch
+
+import comfy.model_management
+import comfy.multigpu
+
+
+class MultiGPUCFGSplitNode(io.ComfyNode):
+ """
+ Prepares model to have sampling accelerated via splitting work units.
+
+ Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
+
+ Other than those exceptions, this node can be placed in any order.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="MultiGPU_WorkUnits",
+ display_name="MultiGPU CFG Split",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Model.Input("model"),
+ io.Int.Input("max_gpus", default=2, min=1, step=1),
+ ],
+ outputs=[
+ io.Model.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
+ model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
+ return io.NodeOutput(model)
+
+
+def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device):
+ """Cast compute dtype to one the device supports; no-op if already supported."""
+ weight_dtype = patcher.model_dtype()
+ cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device)
+ if cast_dtype is None:
+ return
+ logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).")
+ patcher.set_model_compute_dtype(cast_dtype)
+
+
+def _remember_base_devices(patcher: ModelPatcher):
+ """Stash the original load/offload device on the underlying model.
+
+ Stored on patcher.model (which is shared with the input patcher), so
+ later "default" selections can recover the loader's original routing.
+ Only the first Select on a given chain writes these attrs; subsequent
+ deepclones inherit them onto their freshly-loaded model below.
+ """
+ if not hasattr(patcher.model, "_select_base_load_device"):
+ patcher.model._select_base_load_device = patcher.load_device
+ patcher.model._select_base_offload_device = patcher.offload_device
+
+
+def _propagate_base_devices(src_model, dst_model):
+ """Carry the loader-original device attrs onto the freshly-deepcloned model."""
+ if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"):
+ dst_model._select_base_load_device = src_model._select_base_load_device
+ dst_model._select_base_offload_device = src_model._select_base_offload_device
+
+
+def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device):
+ """Return a patcher whose actual model weights live on *target_load_device*.
+
+ If *patcher* is already on *target_load_device* we just retarget the
+ (already-cloned) patcher's metadata in place. Otherwise we call
+ :meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from
+ the loader's ``cached_patcher_init`` factory -- the only safe way to
+ move weights that may already be partially loaded onto another device.
+
+ NOTE: reusing the input patcher's model when the requested device
+ matches its current load_device is a deliberate fast path. Anything
+ that has already mutated the original model (e.g. a prior KSampler
+ invocation on the same model) will be observed here. This is by
+ design and documented on the SelectXDeviceNode docstrings -- placing
+ Select X Device after a node that consumes the same model is not
+ recommended.
+ """
+ if patcher.load_device == target_load_device:
+ # Fast path: weights already on the desired device, just update offload.
+ patcher.offload_device = target_offload_device
+ return patcher
+ src_model = patcher.model
+ patcher = patcher.deepclone_multigpu(new_load_device=target_load_device)
+ patcher.offload_device = target_offload_device
+ _propagate_base_devices(src_model, patcher.model)
+ if hasattr(patcher, "register_load_device"):
+ patcher.register_load_device(patcher.load_device)
+ return patcher
+
+
+def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None):
+ """Resolve the requested device and produce a patcher routed there.
+
+ For "default" we restore the loader's original load/offload pair.
+ For CPU we pin both load and offload to CPU (and, on a dynamic
+ patcher, downgrade to a plain ModelPatcher so the dynamic-only
+ code paths are bypassed).
+ For an explicit GPU we keep the loader's original offload but
+ target the requested load device; if that differs from the current
+ load device the patcher is deepcloned onto the new device.
+ """
+ _remember_base_devices(patcher)
+ base_load = patcher.model._select_base_load_device
+ base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device
+
+ if resolved is None:
+ # "default" -> route back to the loader's original devices.
+ return _retarget_patcher(patcher, base_load, base_offload)
+ if resolved.type == "cpu":
+ if patcher.is_dynamic():
+ # clone(disable_dynamic=True) requires cached_patcher_init; let the
+ # exception surface to the caller (Select*DeviceNode.execute), which
+ # will translate it into a passthrough+log so unsupported loaders
+ # don't hard-fail the workflow.
+ patcher = patcher.clone(disable_dynamic=True)
+ patcher.load_device = resolved
+ patcher.offload_device = resolved
+ return patcher
+ return _retarget_patcher(patcher, resolved, base_offload)
+
+
+def _prune_multigpu_collision(model: ModelPatcher, primary_device):
+ """Drop any multigpu clone whose load_device matches *primary_device*.
+
+ Without pruning, MultiGPU CFG Split would have stacked a clone on
+ the same device the primary now occupies (i.e. the workflow places
+ MultiGPU CFG Split before Select Model Device). Keeps the clone set
+ consistent with the new primary placement.
+ """
+ multigpu_models = model.get_additional_models_with_key("multigpu")
+ if not multigpu_models:
+ return
+ filtered = [m for m in multigpu_models if m.load_device != primary_device]
+ if len(filtered) != len(multigpu_models):
+ logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.")
+ model.set_additional_models("multigpu", filtered)
+ if hasattr(model, "match_multigpu_clones"):
+ model.match_multigpu_clones()
+
+
+class SelectModelDeviceNode(io.ComfyNode):
+ """
+ Place the diffusion model on a specific device (default / cpu / gpu:N).
+
+ - "default" restores the device assigned by the loader (even after a
+ prior Select Model Device call).
+ - "cpu" pins both the load and offload device to CPU.
+ - "gpu:N" pins the load device to the Nth available GPU; the offload
+ device is restored to the loader's original choice.
+
+ When the requested device differs from the device the input model is
+ already on, a fresh model is spawned via the loader's reload factory
+ (cached_patcher_init) so the new patcher owns independent weights on
+ the new device. Loaders that don't support multigpu (no factory) will
+ cause the node to pass through unchanged with a warning.
+
+ If the workflow already has MultiGPU CFG Split applied and the chosen
+ GPU collides with one of the existing multigpu clones, that clone is
+ dropped so two patchers don't end up bound to the same device.
+
+ When the selected device does not exist on the current machine
+ (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
+ the node passes the model through unchanged and logs a message
+ instead of failing.
+
+ NOTE: Placing Select Model Device *after* a node that has already
+ consumed the same model (e.g. a KSampler that ran on this model on
+ the original device) is not recommended -- any state the prior
+ consumer mutated on the original model will be observed when the
+ selected device matches the original (fast path). Place Select Model
+ Device before any consumer of the model.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SelectModelDevice",
+ display_name="Select Model Device",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Model.Input("model"),
+ io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
+ ],
+ outputs=[
+ io.Model.Output(),
+ ],
+ )
+
+ @classmethod
+ def validate_inputs(cls, device="default"):
+ # Allow unknown gpu:N values so portable workflows do not error
+ # at validation time; runtime fallback will handle them.
+ return True
+
+ @classmethod
+ def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput:
+ model = model.clone()
+ resolved = comfy.model_management.resolve_gpu_device_option(device)
+ if resolved is None and device not in (None, "default"):
+ logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.")
+ return io.NodeOutput(model)
+ try:
+ model = _apply_patcher_device(model, resolved)
+ except RuntimeError as e:
+ logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})")
+ return io.NodeOutput(model)
+ if resolved is not None:
+ _force_supported_compute_dtype(model, resolved)
+ _prune_multigpu_collision(model, model.load_device)
+ return io.NodeOutput(model)
+
+
+class SelectCLIPDeviceNode(io.ComfyNode):
+ """
+ Place the CLIP text encoder on a specific device (default / cpu / gpu:N).
+
+ - "default" restores the device assigned by the loader.
+ - "cpu" pins both the load and offload device to CPU.
+ - "gpu:N" pins the load device to the Nth available GPU.
+
+ When the selected device does not exist on the current machine
+ (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
+ the node passes the CLIP through unchanged and logs a message
+ instead of failing.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SelectCLIPDevice",
+ display_name="Select CLIP Device",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Clip.Input("clip"),
+ io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()),
+ ],
+ outputs=[
+ io.Clip.Output(),
+ ],
+ )
+
+ @classmethod
+ def validate_inputs(cls, device="default"):
+ return True
+
+ @classmethod
+ def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput:
+ clip = clip.clone()
+ resolved = comfy.model_management.resolve_gpu_device_option(device)
+ if resolved is None and device not in (None, "default"):
+ logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.")
+ return io.NodeOutput(clip)
+ try:
+ clip.patcher = _apply_patcher_device(clip.patcher, resolved)
+ except RuntimeError as e:
+ logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})")
+ return io.NodeOutput(clip)
+
+
+class SelectVAEDeviceNode(io.ComfyNode):
+ """
+ Place the VAE on a specific device (default / gpu:N).
+
+ - "default" restores the device assigned by the loader.
+ - "gpu:N" pins the load device to the Nth available GPU; the offload
+ device is set to the standard VAE offload device.
+
+ CPU is intentionally not exposed in the UI for the VAE; if a workflow
+ supplies "cpu" anyway (e.g. opened from another machine), the request
+ is dropped with a log message and the VAE is passed through unchanged.
+
+ When the selected device does not exist on the current machine
+ (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box),
+ the node passes the VAE through unchanged and logs a message
+ instead of failing.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="SelectVAEDevice",
+ display_name="Select VAE Device",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Vae.Input("vae"),
+ io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()),
+ ],
+ outputs=[
+ io.Vae.Output(),
+ ],
+ )
+
+ @classmethod
+ def validate_inputs(cls, device="default"):
+ return True
+
+ @classmethod
+ def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput:
+ # VAE has no .clone(); shallow-copy the wrapper and clone the patcher
+ # so we can retarget load/offload device without affecting the input VAE.
+ vae = copy.copy(vae)
+ vae.patcher = vae.patcher.clone()
+ resolved = comfy.model_management.resolve_gpu_device_option(device)
+ if resolved is None and device not in (None, "default"):
+ logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.")
+ return io.NodeOutput(vae)
+ if resolved is not None and resolved.type == "cpu":
+ logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.")
+ return io.NodeOutput(vae)
+ if not hasattr(vae, "_select_base_device"):
+ vae._select_base_device = vae.device
+ try:
+ vae.patcher = _apply_patcher_device(
+ vae.patcher, resolved,
+ base_offload_override=comfy.model_management.vae_offload_device(),
+ )
+ except RuntimeError as e:
+ logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})")
+ return io.NodeOutput(vae)
+ # Keep VAE wrapper in sync with whatever model the patcher now owns;
+ # deepclone_multigpu may have produced a fresh first_stage_model.
+ vae.first_stage_model = vae.patcher.model
+ vae.device = vae._select_base_device if resolved is None else resolved
+ return io.NodeOutput(vae)
+
+
+class MultiGPUOptionsNode(io.ComfyNode):
+ """
+ Select the relative speed of GPUs in the special case they have significantly different performance from one another.
+
+ NOTE (not registered yet, see MultiGPUExtension.get_node_list below):
+ The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on
+ model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond
+ scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult
+ relative_speed when distributing conds across devices; it uses a uniform conds_per_device
+ round-robin via next_available_device(). Before re-enabling this node, wire its
+ relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(),
+ which already implements the proportional split) so the input actually affects work
+ distribution.
+ """
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="MultiGPU_Options",
+ display_name="MultiGPU Options",
+ category="advanced/multigpu",
+ description=cleandoc(cls.__doc__),
+ inputs=[
+ io.Int.Input("device_index", default=0, min=0, max=64),
+ io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01),
+ io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True),
+ ],
+ outputs=[
+ io.Custom("GPU_OPTIONS").Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput:
+ if not gpu_options:
+ gpu_options = comfy.multigpu.GPUOptionsGroup()
+ else:
+ gpu_options = gpu_options.clone()
+
+ opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed)
+ gpu_options.add(opt)
+
+ return io.NodeOutput(gpu_options)
+
+
+class MultiGPUExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ MultiGPUCFGSplitNode,
+ SelectModelDeviceNode,
+ SelectCLIPDeviceNode,
+ SelectVAEDeviceNode,
+ # MultiGPUOptionsNode,
+ ]
+
+
+async def comfy_entrypoint() -> MultiGPUExtension:
+ return MultiGPUExtension()
diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py
index e38a33c15..01593b6e6 100644
--- a/comfy_extras/nodes_number_convert.py
+++ b/comfy_extras/nodes_number_convert.py
@@ -4,7 +4,6 @@ Provides a single node that converts INT, FLOAT, STRING, and BOOL
inputs into FLOAT and INT outputs.
"""
-from __future__ import annotations
import math
diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py
index e104c8480..df7a0b76a 100644
--- a/comfy_extras/nodes_painter.py
+++ b/comfy_extras/nodes_painter.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import hashlib
import os
diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py
new file mode 100644
index 000000000..811b9ae8e
--- /dev/null
+++ b/comfy_extras/nodes_pid.py
@@ -0,0 +1,55 @@
+"""PiD (Pixel Diffusion Decoder) node"""
+
+import torch
+from typing_extensions import override
+
+import node_helpers
+import comfy.latent_formats
+from comfy_api.latest import ComfyExtension, io
+
+
+class PiDConditioning(io.ComfyNode):
+ @classmethod
+ def define_schema(cls) -> io.Schema:
+ return io.Schema(
+ node_id="PiDConditioning",
+ display_name="PiD Conditioning",
+ category="advanced/conditioning",
+ description=(
+ "Attaches a latent and a degrade_sigma scalar to a CONDITIONING for PiD decoding/upscaling"
+ ),
+ inputs=[
+ io.Conditioning.Input("positive"),
+ io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."),
+ io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux",
+ tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."),
+ io.Float.Input(
+ "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,
+ tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.",
+ ),
+ ],
+ outputs=[io.Conditioning.Output()],
+ )
+
+ @classmethod
+ def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput:
+ samples = latent["samples"]
+ if latent_format == "flux":
+ fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux
+ else:
+ fmt_cls = comfy.latent_formats.SD3
+ lq_latent = fmt_cls().process_in(samples)
+ sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
+ return io.NodeOutput(node_helpers.conditioning_set_values(
+ positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t},
+ ))
+
+
+class PiDExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [PiDConditioning]
+
+
+async def comfy_entrypoint() -> PiDExtension:
+ return PiDExtension()
diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py
index 520b4067e..1628038cc 100644
--- a/comfy_extras/nodes_resolution.py
+++ b/comfy_extras/nodes_resolution.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
import math
from enum import Enum
from typing_extensions import override
diff --git a/comfy_extras/nodes_toolkit.py b/comfy_extras/nodes_toolkit.py
index ae802896b..0548a0cf8 100644
--- a/comfy_extras/nodes_toolkit.py
+++ b/comfy_extras/nodes_toolkit.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 78a2a28f8..ae1d826d5 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import os
import av
import torch
diff --git a/folder_paths.py b/folder_paths.py
index 36d61fcd0..7304e1b73 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import os
import time
import mimetypes
diff --git a/main.py b/main.py
index 26d523c30..bce451a83 100644
--- a/main.py
+++ b/main.py
@@ -218,7 +218,7 @@ import comfy.model_patcher
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
- elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
+ elif comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices()):
if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug()
elif args.verbose == 'CRITICAL':
diff --git a/nodes.py b/nodes.py
index 820eeef4c..155ecfe9c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,4 +1,3 @@
-from __future__ import annotations
import torch
@@ -795,6 +794,7 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
metadata = None
+ vae_path = None
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
@@ -813,6 +813,14 @@ class VAELoader:
metadata["tae_latent_channels"] = 128
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
+ # Register a reload factory on the patcher so multigpu deepclones
+ # (Select VAE Device, future MultiGPU VAE work-units) can produce
+ # per-device clones from the same loader context. Only set when we
+ # actually have a single backing file -- pixel_space and the
+ # image TAESDs (composed from separate encoder/decoder files via
+ # load_taesd) are not addressable by a single vae_path.
+ if vae_path is not None:
+ vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None))
return (vae,)
class ControlNetLoader:
@@ -961,7 +969,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
+ "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -971,7 +979,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
- DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
+ DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b\n pixeldit: gemma 2 2B elm"
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@@ -2389,6 +2397,7 @@ async def init_builtin_extra_nodes():
"nodes_lt_audio.py",
"nodes_lt.py",
"nodes_hooks.py",
+ "nodes_multigpu.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_video.py",
@@ -2411,6 +2420,7 @@ async def init_builtin_extra_nodes():
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_chroma_radiance.py",
+ "nodes_pid.py",
"nodes_model_patch.py",
"nodes_easycache.py",
"nodes_audio_encoder.py",
diff --git a/openapi.yaml b/openapi.yaml
index 502e518c7..f801a39d9 100644
--- a/openapi.yaml
+++ b/openapi.yaml
@@ -275,7 +275,10 @@ paths:
responses:
"200":
description: Queue updated
-
+ content:
+ application/json:
+ schema:
+ $ref: "#/components/schemas/QueueManageResponse"
'400':
description: Invalid request parameters
content:
@@ -3092,18 +3095,34 @@ paths:
application/json:
schema:
type: object
- required:
- - asset_ids
properties:
+ job_ids:
+ type: array
+ items:
+ type: string
+ description: Job IDs whose associated assets should all be included in the ZIP bundle.
asset_ids:
type: array
items:
type: string
format: uuid
- description: IDs of assets to export
+ description: Asset IDs to include in the ZIP bundle. Additive to assets associated with provided job IDs.
export_name:
type: string
description: Name for the export archive
+ naming_strategy:
+ type: string
+ enum: [group_by_job_id, preserve, asset_id, group_by_job_time]
+ default: group_by_job_time
+ description: "Strategy for naming files in the ZIP: group by job ID, preserve original names, use the asset ID, or group by job creation time."
+ job_asset_name_filters:
+ type: object
+ additionalProperties:
+ type: array
+ minItems: 1
+ items:
+ type: string
+ description: Optional per-job asset name filters. When provided for a job ID, only assets whose name matches one of the listed names are included.
responses:
"202":
description: Export task accepted
@@ -3575,10 +3594,7 @@ paths:
content:
application/json:
schema:
- type: array
- items:
- $ref: "#/components/schemas/HubLabel"
-
+ $ref: "#/components/schemas/HubLabelListResponse"
'400':
description: Bad request (e.g. invalid type parameter)
content:
@@ -7466,6 +7482,25 @@ components:
type: string
description: Array of prompt IDs to delete from queue
+ QueueManageResponse:
+ type: object
+ x-runtime: [cloud]
+ description: >-
+ [cloud-only] Result of a queue mutation. The Cloud runtime returns which
+ items were deleted and whether the queue was cleared; local ComfyUI
+ returns an empty 200 body.
+ properties:
+ deleted:
+ type: array
+ nullable: true
+ items:
+ type: string
+ description: Prompt IDs that were deleted from the queue.
+ cleared:
+ type: boolean
+ nullable: true
+ description: Whether the queue was cleared.
+
# -------------------------------------------------------------------
# History
# -------------------------------------------------------------------
@@ -7546,6 +7581,16 @@ components:
outputs_count:
type: integer
description: Total number of output files
+ workflow_id:
+ type: string
+ nullable: true
+ x-runtime: [cloud]
+ description: "[cloud-only] UUID of the Cloud workflow entity this job is associated with. Local ComfyUI returns null."
+ execution_error:
+ x-runtime: [cloud]
+ description: "[cloud-only] Detailed execution error from ComfyUI for failed jobs. Absent on local ComfyUI."
+ allOf:
+ - $ref: "#/components/schemas/ExecutionError"
JobDetailResponse:
type: object
@@ -10433,6 +10478,19 @@ components:
- custom_node
description: Label category.
+ HubLabelListResponse:
+ type: object
+ x-runtime: [cloud]
+ description: '[cloud-only] Response wrapper for the available Hub label catalog.'
+ required:
+ - labels
+ properties:
+ labels:
+ type: array
+ items:
+ $ref: '#/components/schemas/HubLabelInfo'
+ description: Available labels, optionally filtered by type.
+
HubProfileSummary:
type: object
x-runtime: [cloud]
diff --git a/requirements.txt b/requirements.txt
index 2ca6d8929..9308e29d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.44.19
-comfyui-workflow-templates==0.9.82
+comfyui-workflow-templates==0.9.85
comfyui-embedded-docs==0.5.1
torch
torchsde
diff --git a/server.py b/server.py
index 44470b904..268441bd1 100644
--- a/server.py
+++ b/server.py
@@ -646,18 +646,37 @@ class PromptServer():
@routes.get("/system_stats")
async def system_stats(request):
- device = comfy.model_management.get_torch_device()
- device_name = comfy.model_management.get_torch_device_name(device)
+ primary_device = comfy.model_management.get_torch_device()
cpu_device = comfy.model_management.torch.device("cpu")
ram_total = comfy.model_management.get_total_memory(cpu_device)
ram_free = comfy.model_management.get_free_memory(cpu_device)
- vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
- vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)
required_frontend_version = FrontendManager.get_required_frontend_version()
installed_templates_version = FrontendManager.get_installed_templates_version()
required_templates_version = FrontendManager.get_required_templates_version()
comfy_package_versions = FrontendManager.get_comfy_package_versions()
+ # Report every torch device visible to multigpu, with the primary
+ # device first so existing clients that read devices[0] keep working.
+ torch_devices = comfy.model_management.get_all_torch_devices()
+ if primary_device in torch_devices:
+ torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device]
+ else:
+ torch_devices = [primary_device] + list(torch_devices)
+
+ device_entries = []
+ for d in torch_devices:
+ vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True)
+ vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True)
+ device_entries.append({
+ "name": comfy.model_management.get_torch_device_name(d),
+ "type": d.type,
+ "index": d.index,
+ "vram_total": vram_total,
+ "vram_free": vram_free,
+ "torch_vram_total": torch_vram_total,
+ "torch_vram_free": torch_vram_free,
+ })
+
system_stats = {
"system": {
"os": sys.platform,
@@ -673,17 +692,7 @@ class PromptServer():
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"argv": sys.argv
},
- "devices": [
- {
- "name": device_name,
- "type": device.type,
- "index": device.index,
- "vram_total": vram_total,
- "vram_free": vram_free,
- "torch_vram_total": torch_vram_total,
- "torch_vram_free": torch_vram_free,
- }
- ]
+ "devices": device_entries
}
return web.json_response(system_stats)