diff --git a/app/assets/services/metadata_extract.py b/app/assets/services/metadata_extract.py index a004929bc..bdfe60218 100644 --- a/app/assets/services/metadata_extract.py +++ b/app/assets/services/metadata_extract.py @@ -4,7 +4,6 @@ Tier 1: Filesystem metadata (zero parsing) Tier 2: Safetensors header metadata (fast JSON read only) """ -from __future__ import annotations import json import logging diff --git a/app/custom_node_manager.py b/app/custom_node_manager.py index 281febca9..738af2abd 100644 --- a/app/custom_node_manager.py +++ b/app/custom_node_manager.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import folder_paths import glob diff --git a/app/frontend_management.py b/app/frontend_management.py index 483da2d29..8e84e8dd9 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -1,4 +1,3 @@ -from __future__ import annotations import argparse import logging import os diff --git a/app/model_manager.py b/app/model_manager.py index f124d1117..8f6e34b33 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import base64 import json diff --git a/app/user_manager.py b/app/user_manager.py index 0517b3344..7b11e381c 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -1,4 +1,3 @@ -from __future__ import annotations import json import os import re diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 47b8174f4..9bda414d1 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") -parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.") +parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.") parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") cm_group = parser.add_mutually_exclusive_group() cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 57126fa4a..bb21eb1d1 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -1,6 +1,5 @@ """Comfy-specific type hinting""" -from __future__ import annotations from typing import Literal, TypedDict, Optional from typing_extensions import NotRequired from abc import ABC, abstractmethod diff --git a/comfy/controlnet.py b/comfy/controlnet.py index ba670b16d..6dbbaa959 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -15,13 +15,14 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ - +from __future__ import annotations import torch from enum import Enum import math import os import logging +import copy import comfy.utils import comfy.model_management import comfy.model_detection @@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet import comfy.ldm.qwen_image.controlnet import comfy.cldm.dit_embedder -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from comfy.hooks import HookGroup @@ -64,6 +65,18 @@ class StrengthType(Enum): CONSTANT = 1 LINEAR_UP = 2 +class ControlIsolation: + '''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.''' + def __init__(self, control: ControlBase): + self.control = control + self.orig_previous_controlnet = control.previous_controlnet + + def __enter__(self): + self.control.previous_controlnet = None + + def __exit__(self, *args): + self.control.previous_controlnet = self.orig_previous_controlnet + class ControlBase: def __init__(self): self.cond_hint_original = None @@ -77,7 +90,7 @@ class ControlBase: self.compression_ratio = 8 self.upscale_algorithm = 'nearest-exact' self.extra_args = {} - self.previous_controlnet = None + self.previous_controlnet: Union[ControlBase, None] = None self.extra_conds = [] self.strength_type = StrengthType.CONSTANT self.concat_mask = False @@ -85,6 +98,7 @@ class ControlBase: self.extra_concat = None self.extra_hooks: HookGroup = None self.preprocess_image = lambda a: a + self.multigpu_clones: dict[torch.device, ControlBase] = {} def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): self.cond_hint_original = cond_hint @@ -111,17 +125,38 @@ class ControlBase: def cleanup(self): if self.previous_controlnet is not None: self.previous_controlnet.cleanup() - + for device_cnet in self.multigpu_clones.values(): + with ControlIsolation(device_cnet): + device_cnet.cleanup() self.cond_hint = None self.extra_concat = None self.timestep_range = None def get_models(self): out = [] + for device_cnet in self.multigpu_clones.values(): + out += device_cnet.get_models_only_self() if self.previous_controlnet is not None: out += self.previous_controlnet.get_models() return out + def get_models_only_self(self): + 'Calls get_models, but temporarily sets previous_controlnet to None.' + with ControlIsolation(self): + return self.get_models() + + def get_instance_for_device(self, device): + 'Returns instance of this Control object intended for selected device.' + return self.multigpu_clones.get(device, self) + + def deepclone_multigpu(self, load_device, autoregister=False): + ''' + Create deep clone of Control object where model(s) is set to other devices. + + When autoregister is set to True, the deep clone is also added to multigpu_clones dict. + ''' + raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu funtion.") + def get_extra_hooks(self): out = [] if self.extra_hooks is not None: @@ -130,7 +165,7 @@ class ControlBase: out += self.previous_controlnet.get_extra_hooks() return out - def copy_to(self, c): + def copy_to(self, c: ControlBase): c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range @@ -284,6 +319,14 @@ class ControlNet(ControlBase): self.copy_to(c) return c + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.control_model = copy.deepcopy(c.control_model) + c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + if autoregister: + self.multigpu_clones[load_device] = c + return c + def get_models(self): out = super().get_models() out.append(self.control_model_wrapped) @@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet): super().pre_run(model, percent_to_timestep_function) self.set_extra_arg("base_model", model.diffusion_model) + def cleanup(self): + self.extra_args.pop("base_model", None) + super().cleanup() + def copy(self): c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c.control_model = self.control_model @@ -906,6 +953,14 @@ class T2IAdapter(ControlBase): self.copy_to(c) return c + def deepclone_multigpu(self, load_device, autoregister=False): + c = self.copy() + c.t2i_model = copy.deepcopy(c.t2i_model) + c.device = load_device + if autoregister: + self.multigpu_clones[load_device] = c + return c + def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options compression_ratio = 8 upscale_algorithm = 'nearest-exact' diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 75d459b59..12a934d71 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -799,13 +799,15 @@ class ZImagePixelSpace(ChromaRadiance): """ pass - class HiDreamO1Pixel(ChromaRadiance): """Pixel-space latent format for HiDream-O1. No VAE — model patches/unpatches raw RGB internally with patch_size=32. """ pass +class PixelDiTPixel(ChromaRadiance): + pass + class CogVideoX(LatentFormat): """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index bc36b8998..4e4819fe3 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module): def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - if context.shape[0] >= 2: - uncond_emb, cond_emb = context.chunk(2, dim = 0) - context = torch.cat([cond_emb, uncond_emb], dim = 0) + + swap_cfg_halves = context.shape[0] >= 2 + + if swap_cfg_halves: + first_half, second_half = context.chunk(2, dim = 0) + context = torch.cat([second_half, first_half], dim = 0) + main_condition = context t = 1.0 - t @@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module): output = self.final_layer(combined) output = output.movedim(-2, -1) * (-1.0) - if output.shape[0] >= 2: - cond_emb, uncond_emb = output.chunk(2, dim = 0) - return torch.cat([uncond_emb, cond_emb]) - else: - return output + if swap_cfg_halves: + first_half, second_half = output.chunk(2, dim = 0) + output = torch.cat([second_half, first_half], dim = 0) + + return output diff --git a/comfy/ldm/lens/model.py b/comfy/ldm/lens/model.py new file mode 100644 index 000000000..cd5015ddc --- /dev/null +++ b/comfy/ldm/lens/model.py @@ -0,0 +1,510 @@ +"""Lens denoising transformer (DiT)""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ldm.flux.layers +import comfy.patcher_extension +from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention + + +def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor: + return comfy.ldm.flux.layers.timestep_embedding(t, dim) + + +def _lens_position_ids( + frame: int, height: int, width: int, text_seq_len: int, + scale_rope: bool = True, device=None, +) -> torch.Tensor: + """Lens axial (frame, h, w) position ids for joint image + text sequence. + + With ``scale_rope=True`` h/w are centered around 0 (negative + positive + halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``; + caller adds a batch dim for ``EmbedND``. + """ + if scale_rope: + h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device), + torch.arange(0, height // 2, device=device)]) + w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device), + torch.arange(0, width // 2, device=device)]) + text_start = max(height // 2, width // 2) + else: + h_pos = torch.arange(height, device=device) + w_pos = torch.arange(width, device=device) + text_start = max(height, width) + + f_pos = torch.arange(frame, device=device) + img_ids = torch.zeros(frame, height, width, 3, device=device) + img_ids[..., 0] = f_pos[:, None, None] + img_ids[..., 1] = h_pos[None, :, None] + img_ids[..., 2] = w_pos[None, None, :] + img_ids = img_ids.reshape(-1, 3) + + # Text positions replicate across all 3 axes (matches original packing). + txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float() + txt_ids = txt_pos[:, None].expand(text_seq_len, 3) + + return torch.cat([img_ids, txt_ids], dim=0) + + +class _TimestepEmbedder(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device) + self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = F.silu(x) + return self.linear_2(x) + + +class LensTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations) + + def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + proj = _lens_time_proj(timestep, 256) + return self.timestep_embedder(proj.to(dtype=hidden_states.dtype)) + + +class GateMLP(nn.Module): + """SwiGLU MLP.""" + + def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x))) + + +class LensJointAttention(nn.Module): + """Joint image+text attention with fused QKV per stream.""" + + def __init__( + self, + query_dim: int, + added_kv_proj_dim: int, + dim_head: int = 64, + heads: int = 8, + out_dim: Optional[int] = None, + eps: float = 1e-5, + dtype=None, + device=None, + operations=None, + ) -> None: + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.heads = self.inner_dim // dim_head + self.dim_head = dim_head + self.out_dim = out_dim if out_dim is not None else query_dim + + self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + + self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device) + self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device) + + # ModuleList([Linear, Identity]) for state-dict key compatibility. + self.to_out = nn.ModuleList([ + operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device), + nn.Identity(), + ]) + self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + bsz, seq_img, _ = hidden_states.shape + seq_txt = encoder_hidden_states.shape[1] + + # image stream + img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head) + img_q, img_k, img_v = img_qkv.unbind(dim=2) + img_q = self.norm_q(img_q) + img_k = self.norm_k(img_k) + del img_qkv + + # text stream + txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head) + txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2) + txt_q = self.norm_added_q(txt_q) + txt_k = self.norm_added_k(txt_k) + + # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks + q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2) + del img_q, txt_q + k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2) + del img_k, txt_k + v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2) + del img_v, txt_v + + q, k = apply_rope(q, k, freqs_cis) + + if attention_mask is not None: + expected = (bsz, 1, 1, seq_img + seq_txt) + if attention_mask.shape != expected: + raise ValueError( + f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}" + ) + attention_mask = attention_mask.to(q.dtype) + + out = optimized_attention( + q, k, v, self.heads, mask=attention_mask, skip_reshape=True, + transformer_options=transformer_options, + ) + + img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :])) + txt_out = self.to_add_out(out[:, seq_img:, :]) + return img_out, txt_out + + +class LensTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + rms_norm: bool = True, + dtype=None, + device=None, + operations=None, + ) -> None: + super().__init__() + + self.attn = LensJointAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + eps=1e-5, + dtype=dtype, + device=device, + operations=operations, + ) + + if rms_norm: + NormCls = operations.RMSNorm + norm_kwargs = {} + else: + NormCls = operations.LayerNorm + norm_kwargs = {"elementwise_affine": False} + + mlp_hidden = int(dim / 3 * 8) + + # Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}. + self.img_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations) + + @staticmethod + def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1) + txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1) + + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) + + img_attn, txt_attn = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + transformer_options=transformer_options, + ) + + hidden_states = hidden_states + img_gate1 * img_attn + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn + + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) + hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2) + + txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2) + + return encoder_hidden_states, hidden_states + + +class _AdaLayerNormContinuousNoAffine(nn.Module): + """AdaLayerNormContinuous(elementwise_affine=False). + + The reference uses ``scale, shift = chunk(2)`` (scale first) — opposite + to Flux's ``LastLayer``. + """ + + def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6, + dtype=None, device=None, operations=None) -> None: + super().__init__() + self.linear = operations.Linear( + conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device + ) + self.eps = eps + self.embedding_dim = embedding_dim + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + emb = self.linear(F.silu(conditioning)) + scale, shift = torch.chunk(emb, 2, dim=-1) + x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class LensTransformer2DModel(nn.Module): + """Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text).""" + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 128, + out_channels: Optional[int] = 32, + num_layers: int = 48, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + enc_hidden_dim: int = 2880, + axes_dims_rope: Tuple[int, int, int] = (8, 28, 28), + rms_norm: bool = True, + multi_layer_encoder_feature: bool = True, + selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23), + image_model=None, # unused; accepted for detection-side configs. + dtype=None, + device=None, + operations=None, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels if out_channels is not None else in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.multi_layer_encoder_feature = multi_layer_encoder_feature + self.selected_layer_index = list(selected_layer_index) + self.dtype = dtype + + self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) + self.time_text_embed = LensTimestepProjEmbeddings( + embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations + ) + + if self.multi_layer_encoder_feature: + self.txt_norm = nn.ModuleList( + [operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device) + for _ in self.selected_layer_index] + ) + self.txt_in = operations.Linear( + enc_hidden_dim * len(self.selected_layer_index), + self.inner_dim, bias=True, dtype=dtype, device=device, + ) + else: + self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device) + self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device) + + self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + + self.transformer_blocks = nn.ModuleList([ + LensTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + eps=1e-6, + rms_norm=rms_norm, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_out = _AdaLayerNormContinuousNoAffine( + self.inner_dim, self.inner_dim, eps=1e-6, + dtype=dtype, device=device, operations=operations, + ) + self.proj_out = operations.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, + dtype=dtype, device=device, + ) + + def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor: + if transformer_options is None: + transformer_options = {} + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timestep, context, attention_mask, transformer_options, **kwargs) + + def _forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, + control: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> torch.Tensor: + """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``.""" + if transformer_options is None: + transformer_options = {} + transformer_options = transformer_options.copy() + patches = transformer_options.get("patches", {}) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + B, C, h, w = x.shape + hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C) + + if self.multi_layer_encoder_feature: + L = len(self.selected_layer_index) + enc_dim = context.shape[-1] // L + encoder_hidden_states = list( + context.reshape(B, -1, L, enc_dim).unbind(dim=2) + ) + text_seq_len = encoder_hidden_states[0].shape[1] + else: + encoder_hidden_states = context + text_seq_len = context.shape[1] + + if attention_mask is None: + attention_mask = torch.ones( + (B, text_seq_len), dtype=torch.bool, device=x.device + ) + + img_len = h * w + joint_mask = self._build_joint_attention_mask(attention_mask, img_len) + + hidden_states = self.img_in(hidden_states) + timestep = timestep.to(hidden_states.dtype) + + if self.multi_layer_encoder_feature: + normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)] + encoder_hidden_states = torch.cat(normed, dim=-1) + else: + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if "post_input" in patches: + for p in patches["post_input"]: + out = p({ + "img": hidden_states, + "txt": encoder_hidden_states, + "transformer_options": transformer_options, + }) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + + temb = self.time_text_embed(timestep, hidden_states) + ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0) + freqs_cis = self.pos_embed(ids) + + transformer_options["total_blocks"] = len(self.transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block( + hidden_states=args["img"], + encoder_hidden_states=args["txt"], + temb=args["vec"], + freqs_cis=args["pe"], + attention_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options"), + ) + return out + out = blocks_replace[("double_block", i)]( + { + "img": hidden_states, + "txt": encoder_hidden_states, + "vec": temb, + "pe": freqs_cis, + "attn_mask": joint_mask, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + ) + encoder_hidden_states = out["txt"] + hidden_states = out["img"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + freqs_cis=freqs_cis, + attention_mask=joint_mask, + transformer_options=transformer_options, + ) + + if "double_block" in patches: + for p in patches["double_block"]: + out = p({ + "img": hidden_states, + "txt": encoder_hidden_states, + "x": x, + "block_index": i, + "transformer_options": transformer_options, + }) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + + if control is not None: + control_i = control.get("input") + if control_i is not None and i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states[:, :add.shape[1]] += add + + hidden_states = self.norm_out(hidden_states, temb) + out = self.proj_out(hidden_states) + return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous() + + @staticmethod + def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor: + if text_mask.dtype != torch.bool: + text_mask = text_mask.bool() + bsz = text_mask.shape[0] + img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device) + joint = torch.cat([img_ones, text_mask], dim=1) + additive = torch.zeros_like(joint, dtype=torch.float32) + additive.masked_fill_(~joint, torch.finfo(torch.float32).min) + return additive[:, None, None, :] diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index bc09fb77e..ef9938465 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -767,25 +767,25 @@ class LTXAVModel(LTXVModel): # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - timestep.max().expand_as(a_timestep_flat), + a_timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - a_timestep.max().expand_as(timestep_flat), + timestep_flat, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - a_timestep.max().expand_as(timestep_flat) * av_ca_factor, + a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - timestep.max().expand_as(a_timestep_flat) * av_ca_factor, + timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py index b556b128f..58b67d45a 100644 --- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -1,4 +1,3 @@ -from __future__ import annotations import torch from torch import nn from torch.nn import functional as F diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 998122c85..5975015e2 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -1,4 +1,3 @@ -from __future__ import annotations import threading import torch from torch import nn diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 9e432d5c0..d0ee97d33 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -1,5 +1,4 @@ # Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py -from __future__ import annotations from typing import List, Optional, Tuple diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 0dc8fe789..9ab3c463c 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000): super().__init__() if output_size is None: output_size = hidden_size @@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module): operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period def forward(self, t, dtype, **kwargs): - t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype) t_emb = self.mlp(t_freq) return t_emb diff --git a/comfy/ldm/moge/geometry.py b/comfy/ldm/moge/geometry.py index 7fdc97871..d1a1e445f 100644 --- a/comfy/ldm/moge/geometry.py +++ b/comfy/ldm/moge/geometry.py @@ -1,6 +1,5 @@ """Pure-torch + scipy geometry helpers for MoGe inference and mesh export.""" -from __future__ import annotations from typing import Optional, Tuple diff --git a/comfy/ldm/moge/model.py b/comfy/ldm/moge/model.py index 6876c4af2..1695626bc 100644 --- a/comfy/ldm/moge/model.py +++ b/comfy/ldm/moge/model.py @@ -4,7 +4,6 @@ V1: DINOv2 backbone + multi-output head (points, mask). V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP). """ -from __future__ import annotations from numbers import Number from typing import Any, Dict, List, Optional, Tuple, Union diff --git a/comfy/ldm/moge/modules.py b/comfy/ldm/moge/modules.py index 235a59212..f6443d65a 100644 --- a/comfy/ldm/moge/modules.py +++ b/comfy/ldm/moge/modules.py @@ -1,6 +1,5 @@ """Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head.""" -from __future__ import annotations from typing import List, Optional, Sequence, Tuple, Union diff --git a/comfy/ldm/moge/panorama.py b/comfy/ldm/moge/panorama.py index de53ebe68..18d0cb665 100644 --- a/comfy/ldm/moge/panorama.py +++ b/comfy/ldm/moge/panorama.py @@ -6,7 +6,6 @@ equirect distance map via a multi-scale Poisson + gradient sparse solve. Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU). """ -from __future__ import annotations from typing import Callable, List, Optional, Tuple diff --git a/comfy/ldm/pixeldit/model.py b/comfy/ldm/pixeldit/model.py new file mode 100644 index 000000000..b044b9b29 --- /dev/null +++ b/comfy/ldm/pixeldit/model.py @@ -0,0 +1,239 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ldm.common_dit +import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.hidream.model import FeedForwardSwiGLU +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder + +from .modules import ( + FinalLayer, + PatchTokenEmbedder, + PiTBlock, + PixelTokenEmbedder, + apply_adaln_, + precompute_freqs_cis_2d, +) + + +class MMDiTJointAttention(nn.Module): + """Joint MMDiT attention with separate Q/K/V/proj for image and text streams. + + RoPE is applied to each stream before concatenation so each stream uses its own + 2D/1D positional encoding. Concat order is [text, image] (text first). + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + + self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + + self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + B, Nx, _ = x.shape + _, Ny, _ = y.shape + H = self.num_heads + D = self.head_dim + + qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4) + qx, kx, vx = qkv_x.unbind(0) + qx = self.q_norm_x(qx) + kx = self.k_norm_x(kx) + + qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4) + qy, ky, vy = qkv_y.unbind(0) + qy = self.q_norm_y(qy) + ky = self.k_norm_y(ky) + + qx, kx = apply_rope(qx, kx, pos_img[None, None]) + if pos_txt is not None: + qy, ky = apply_rope(qy, ky, pos_txt[None, None]) + + q_joint = torch.cat([qy, qx], dim=2) + k_joint = torch.cat([ky, kx], dim=2) + v_joint = torch.cat([vy, vx], dim=2) + + out_joint = optimized_attention( + q_joint, k_joint, v_joint, H, + mask=attn_mask, skip_reshape=True, skip_output_reshape=True, + transformer_options=transformer_options, + ) + + out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D) + out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D) + + return self.proj_x(out_x), self.proj_y(out_y) + + +class MMDiTBlockT2I(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations) + self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations) + self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations) + self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)) + self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)) + + def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1) + shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1) + + x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x) + y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y) + attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options) + x = torch.addcmul(x, gate_msa_x, attn_x) + y = torch.addcmul(y, gate_msa_y, attn_y) + + x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x))) + y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y))) + return x, y + + +class PixDiT_T2I(nn.Module): + """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint + (also runs at 512px when fed the appropriate latent size and flow_shift). + + Forward: + x: [B, 3, H, W] pixel-space input (no VAE) + timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention) + context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended) + Returns flow-matching velocity [B, 3, H, W]. + """ + def __init__( + self, + in_channels=3, + num_groups=24, + hidden_size=1536, + pixel_hidden_size=16, + pixel_attn_hidden_size=1152, + pixel_num_groups=16, + patch_depth=14, + pixel_depth=2, + patch_size=16, + txt_embed_dim=2304, + txt_max_length=300, + use_text_rope=True, + text_rope_theta=10000.0, + image_model=None, + dtype=None, + device=None, + operations=None, + pixel_mlp_chunks=2, + ): + super().__init__() + self.dtype = dtype + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.patch_depth = patch_depth + self.pixel_depth = pixel_depth + self.patch_size = patch_size + self.pixel_hidden_size = pixel_hidden_size + self.pixel_attn_hidden_size = pixel_attn_hidden_size + self.pixel_num_groups = pixel_num_groups + self.txt_embed_dim = txt_embed_dim + self.txt_max_length = txt_max_length + self.use_text_rope = use_text_rope + self.text_rope_theta = text_rope_theta + + self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations) + self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations) + self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10) + self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations) + self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device)) + + self.patch_blocks = nn.ModuleList([ + MMDiTBlockT2I(self.hidden_size, self.num_groups, + dtype=dtype, device=device, operations=operations) + for _ in range(self.patch_depth) + ]) + self.pixel_blocks = nn.ModuleList([ + PiTBlock( + self.pixel_hidden_size, + self.hidden_size, + patch_size=self.patch_size, + num_heads=self.num_groups, + attn_hidden_size=self.pixel_attn_hidden_size, + attn_num_heads=self.pixel_num_groups, + dtype=dtype, device=device, operations=operations, + mlp_chunks=pixel_mlp_chunks, + ) + for _ in range(self.pixel_depth) + ]) + + self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations) + + def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts): + return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts) + + def _fetch_text_pos(self, length, device, dtype): + return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype) + + def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + + def _pre_patch_block(self, s, i, **kwargs): + """Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate).""" + return s + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + H_orig, W_orig = x.shape[2], x.shape[3] + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + B, _, H, W = x.shape + Hs = H // self.patch_size + Ws = W // self.patch_size + L = Hs * Ws + + pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {})) + x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + + t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size) + + if context is None or context.dim() != 3: + raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]") + Ltxt = min(context.shape[1], self.txt_max_length) + y = context[:, :Ltxt, :] + y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size) + y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter + + condition = F.silu(t_emb) + pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None + + s = self.s_embedder(x_patches) + for i, blk in enumerate(self.patch_blocks): + s = self._pre_patch_block(s, i, **kwargs) + s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options) + s = F.silu(t_emb + s) + + s_cond = s.view(B * L, self.hidden_size) + x_pixels = self.pixel_embedder(x, patch_size=self.patch_size) + for blk in self.pixel_blocks: + x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options) + + x_pixels = self.final_layer(x_pixels) + C_out = self.out_channels + P2 = self.patch_size * self.patch_size + x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L) + out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return out[:, :, :H_orig, :W_orig] diff --git a/comfy/ldm/pixeldit/modules.py b/comfy/ldm/pixeldit/modules.py new file mode 100644 index 000000000..4b1e538c7 --- /dev/null +++ b/comfy/ldm/pixeldit/modules.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn + +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch + + +def apply_adaln_(x, shift, scale): + return x.addcmul_(x, scale).add_(shift) + + +def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, + ref_grid_h=None, ref_grid_w=None, + scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0, + device=None, dtype=torch.float32, **kwargs): + """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim. + + rope_options: + scale_x / scale_y multiply the position range (RoPE extrapolation). + shift_x / shift_y offset the position origin (tiled / regional inference). + With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling + (rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)). + Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2]. + Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}]. + """ + dim_axis = dim // 2 + if ref_grid_h is not None and dim_axis > 2: + h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2)) + w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2)) + else: + h_ntk = w_ntk = 1.0 + + x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device) + y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device) + y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij") + x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0) + y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0) + out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2) + return out.to(dtype=dtype) + + +def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32): + """Standard 2D sin/cos absolute positional embedding (ViT-style). + + first half encodes W-coordinates, second half H. + """ + assert embed_dim % 4 == 0 + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij") + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device) + return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype) + + +class RotaryAttention(nn.Module): + """Single-stream self-attention with rotary positional encoding (used inside PiTBlock).""" + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, pos, mask=None, transformer_options={}): + B, N, C = x.shape + H = self.num_heads + D = self.head_dim + qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None]) + x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options) + return self.proj(x) + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x): + return self.linear(self.norm(x)) + + +class PatchTokenEmbedder(nn.Module): + """Linear projection used both for patchified-image tokens and text-feature tokens.""" + def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device) + self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity() + + def forward(self, x): + return self.norm(self.proj(x)) + + +class PixelTokenEmbedder(nn.Module): + """Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences.""" + def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None): + super().__init__() + self.in_channels = in_channels + self.hidden_size_output = hidden_size_output + self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device) + + def forward(self, inputs, patch_size): + B, _, H, W = inputs.shape + Hs, Ws = H // patch_size, W // patch_size + P2 = patch_size * patch_size + x = inputs.permute(0, 2, 3, 1).contiguous() + x = self.proj(x) + pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output) + x = x + pos_full.unsqueeze(0) + x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output) + return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output) + + +class PiTBlock(nn.Module): + """Pixel-level transformer block. + + Compresses each patch's P^2 pixel tokens → 1 attention token via a linear, + runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens. + Conditioning is per-pixel adaLN from the patch-level features. + """ + def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0, + attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1): + super().__init__() + self.pixel_dim = pixel_hidden_size + self.context_dim = patch_hidden_size + self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size + self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads + assert self.attn_dim % self.num_heads == 0 + + p2 = patch_size * patch_size + self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device) + self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device) + + self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations) + self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations) + + self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device) + self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device) + + self._rope_fn = precompute_freqs_cis_2d + self.mlp_chunks = max(1, int(mlp_chunks)) + + def _fetch_pos(self, height, width, device, dtype, **rope_opts): + return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts) + + def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}): + BL, P2, _ = x.shape + Hs, Ws = image_height // patch_size, image_width // patch_size + L = Hs * Ws + B = BL // L + + # Attention path uses only msa params; compute, use, free before mlp params allocate. + msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1) + + x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa) + x_flat = x_norm.view(BL, P2 * self.pixel_dim) + + x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim) + pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {})) + attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options) + attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim)) + attn_exp = attn_flat.view(BL, P2, self.pixel_dim) + x = torch.addcmul(x, gate_msa, attn_exp) + del msa_params, shift_msa, scale_msa, gate_msa + + mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1) + gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP + mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp) + del mlp_params, shift_mlp, scale_mlp + + # MLP in chunks since the peak memory usage is huge here + chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks + for s in range(0, BL, chunk_size): + e = min(s + chunk_size, BL) + x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e])) + return x diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py new file mode 100644 index 000000000..0ad4b7ce8 --- /dev/null +++ b/comfy/ldm/pixeldit/pid.py @@ -0,0 +1,226 @@ +"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent +directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I +body + LQ projection branch injected before each MMDiT patch block. +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .model import PixDiT_T2I +from .modules import precompute_freqs_cis_2d + + +class SigmaAwareGatePerTokenPerDim(nn.Module): + """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq. + + Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1. + """ + + def __init__(self, dim: int, dtype=None, device=None, operations=None): + super().__init__() + self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device) + self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + content_logit = self.content_proj(torch.cat([x, lq], dim=-1)) + # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM. + log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32) + sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1) + gate = torch.sigmoid(content_logit + sigma_offset) + return x + (gate * lq).to(x.dtype) + + +class ResBlock(nn.Module): + """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip.""" + + def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None): + super().__init__() + self.block = nn.Sequential( + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + + +class LQProjection2D(nn.Module): + """LQ latent -> per-block patch-aligned features for controlnet-style injection.""" + + def __init__( + self, + latent_channels: int, + hidden_dim: int = 512, + out_dim: int = 1536, + patch_size: int = 16, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + num_res_blocks: int = 4, + num_outputs: int = 7, + interval: int = 2, + dtype=None, device=None, operations=None, + ): + super().__init__() + self.latent_channels = latent_channels + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.patch_size = patch_size + self.sr_scale = sr_scale + self.latent_spatial_down_factor = latent_spatial_down_factor + self.num_outputs = num_outputs + self.interval = interval + + z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size + self.z_to_patch_ratio = z_to_patch_ratio + if z_to_patch_ratio >= 1: + self.latent_fold_factor = 0 + latent_proj_in_ch = latent_channels + else: + fold_factor = int(1 / z_to_patch_ratio) + assert fold_factor * z_to_patch_ratio == 1.0 + self.latent_fold_factor = fold_factor + latent_proj_in_ch = latent_channels * fold_factor * fold_factor + + layers = [ + operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + ] + for _ in range(num_res_blocks): + layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations)) + self.latent_proj = nn.Sequential(*layers) + + self.output_heads = nn.ModuleList( + [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)] + ) + self.gate_modules = nn.ModuleList( + [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations) + for _ in range(num_outputs)] + ) + + def is_gate_active(self, block_idx: int) -> bool: + return block_idx % self.interval == 0 + + def output_index(self, block_idx: int) -> int: + return block_idx // self.interval + + def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor: + return self.gate_modules[out_idx](x, lq_feature, sigma) + + def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor: + B, z_dim = lq_latent.shape[:2] + if self.z_to_patch_ratio >= 1: + if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW: + z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest") + else: + z_aligned = lq_latent + else: + f = self.latent_fold_factor + zH_expected, zW_expected = pH * f, pW * f + if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected: + lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest") + z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4) + z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW) + return self.latent_proj(z_aligned) + + def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]: + feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW) + B, C, H, W = feat.shape + tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + return [head(tokens) for head in self.output_heads] + + +class PidNet(PixDiT_T2I): + """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block).""" + + def __init__( + self, + lq_latent_channels: int = 16, + lq_hidden_dim: int = 512, + lq_num_res_blocks: int = 4, + lq_interval: int = 2, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64. + rope_ref_w: int = 1024, + image_model=None, + dtype=None, device=None, operations=None, + **pixdit_kwargs, + ): + super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs) + + self.rope_ref_grid_h = rope_ref_h // self.patch_size + self.rope_ref_grid_w = rope_ref_w // self.patch_size + + # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware. + def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts): + return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts) + for blk in self.pixel_blocks: + blk._rope_fn = _pit_rope_fn + + num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval + self.lq_proj = LQProjection2D( + latent_channels=lq_latent_channels, + hidden_dim=lq_hidden_dim, + out_dim=self.hidden_size, + patch_size=self.patch_size, + sr_scale=sr_scale, + latent_spatial_down_factor=latent_spatial_down_factor, + num_res_blocks=lq_num_res_blocks, + num_outputs=num_lq_outputs, + interval=lq_interval, + dtype=dtype, + device=device, + operations=operations, + ) + + def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts): + return precompute_freqs_cis_2d( + self.hidden_size // self.num_groups, + height, width, + ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, + device=device, dtype=dtype, **rope_opts, + ) + + def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs): + if not self.lq_proj.is_gate_active(i): + return s + out_idx = self.lq_proj.output_index(i) + if out_idx >= len(pid_lq_features): + return s + return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx) + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs): + if lq_latent is None: + raise ValueError("PidNet requires lq_latent — attach via PiDConditioning") + expected_c = self.lq_proj.latent_channels + if lq_latent.shape[1] != expected_c: + raise ValueError( + f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. " + f"Flux1/SD3 = 16 channels, Flux2 = 128 channels." + ) + B = x.shape[0] + Hs = x.shape[2] // self.patch_size + Ws = x.shape[3] // self.patch_size + + degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) + if degrade_sigma.numel() == 1 and B > 1: + degrade_sigma = degrade_sigma.expand(B).contiguous() + + lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws) + + return super()._forward( + x, timesteps, + context=context, attention_mask=attention_mask, + transformer_options=transformer_options, + pid_lq_features=lq_features, + pid_degrade_sigma=degrade_sigma, + **kwargs, + ) diff --git a/comfy/lora.py b/comfy/lora.py index c0e8b865c..4e0ea29e0 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -16,7 +16,6 @@ along with this program. If not, see . """ -from __future__ import annotations import comfy.memory_management import comfy.utils import comfy.model_management diff --git a/comfy/memory_management.py b/comfy/memory_management.py index c43f0c4a2..962addb27 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,6 +1,5 @@ import math import ctypes -import threading import dataclasses import torch from typing import NamedTuple @@ -10,7 +9,7 @@ from comfy.quant_ops import QuantizedTensor class TensorFileSlice(NamedTuple): file_ref: object - thread_id: int + lock: object offset: int size: int @@ -43,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N file_obj = info.file_ref if (destination.device.type != "cpu" or file_obj is None - or threading.get_ident() != info.thread_id or destination.numel() * destination.element_size() < info.size or tensor.numel() * tensor.element_size() != info.size or tensor.storage_offset() != 0 @@ -57,27 +55,29 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if hostbuf is not None: stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 device_ptr = destination2.data_ptr() if destination2 is not None else 0 - hostbuf.read_file_slice(file_obj, info.offset, info.size, - offset=destination.data_ptr() - hostbuf.get_raw_address(), - stream=stream_ptr, - device_ptr=device_ptr, - device=None if destination2 is None else destination2.device.index) + with info.lock: + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) return True buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) try: - file_obj.seek(info.offset) - done = 0 - while done < info.size: - try: - n = file_obj.readinto(view[done:]) - except OSError: - return False - if n <= 0: - return False - done += n + with info.lock: + file_obj.seek(info.offset) + done = 0 + while done < info.size: + try: + n = file_obj.readinto(view[done:]) + except OSError: + return False + if n <= 0: + return False + done += n return True finally: view.release() diff --git a/comfy/model_base.py b/comfy/model_base.py index d81f13c69..e55808633 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -35,6 +35,7 @@ import comfy.ldm.hydit.models import comfy.ldm.audio.dit import comfy.ldm.audio.embedders import comfy.ldm.flux.model +import comfy.ldm.lens.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model @@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.chroma_radiance.model +import comfy.ldm.pixeldit.model +import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1058,6 +1061,27 @@ class Flux2(Flux): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class Lens(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__( + model_config, model_type, device=device, + unet_model=comfy.ldm.lens.model.LensTransformer2DModel, + ) + + def encode_adm(self, **kwargs): + return None # Lens has no pooled/ADM conditioning. + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + return out + class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint) @@ -1375,6 +1399,36 @@ class ZImagePixelSpace(Lumina2): BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) self.memory_usage_factor_conds = ("ref_latents",) + +class PixelDiTT2I(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out["attention_mask"] = comfy.conds.CONDRegular(attention_mask) + return out + + +class PiD(PixelDiTT2I): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.pid.PidNet) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + lq_latent = kwargs.get("lq_latent", None) + if lq_latent is not None: + out["lq_latent"] = comfy.conds.CONDRegular(lq_latent) + degrade_sigma = kwargs.get("degrade_sigma", None) + if degrade_sigma is not None: + out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma) + return out + + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 70b4df8b3..f0db7d388 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config + # PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I. + _lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix) + if _lq_w_key in state_dict_keys: + in_ch = int(state_dict[_lq_w_key].shape[1]) + _gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix) + num_gates = len({k[len(_gate_prefix):].split('.')[0] + for k in state_dict_keys if k.startswith(_gate_prefix)}) + dit_config = {"image_model": "pid", + "lq_latent_channels": in_ch, + "latent_spatial_down_factor": 16 if in_ch >= 64 else 8} + if num_gates > 0: + dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates + return dit_config + + if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I + return {"image_model": "pixeldit_t2i"} + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" @@ -755,6 +772,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["timestep_scale"] = 1000.0 return dit_config + if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \ + and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens + img_in_w = state_dict['{}img_in.weight'.format(key_prefix)] + proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)] + multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys + if multi_layer: + enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0] + # Indices are TE-side; the DiT just consumes L layers in order. + selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.'))) + else: + enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0] + selected_layer_index = (0,) + + return { + "image_model": "lens", + "in_channels": img_in_w.shape[1], + "out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default) + "num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'), + "num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default + "enc_hidden_dim": enc_hidden_dim, + "multi_layer_encoder_feature": multi_layer, + "selected_layer_index": selected_layer_index, + } + if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" diff --git a/comfy/model_management.py b/comfy/model_management.py index cd8772d3a..b01c4d7fa 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,6 +15,7 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from __future__ import annotations import psutil import logging @@ -27,13 +28,18 @@ import platform import weakref import gc import os -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops import comfy_aimdo.host_buffer import comfy_aimdo.vram_buffer +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + + class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -204,6 +210,107 @@ def get_torch_device(): else: return torch.device(torch.cuda.current_device()) +def get_all_torch_devices(exclude_current=False): + global cpu_state + devices = [] + if cpu_state == CPUState.GPU: + # NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*; + # without the AMD arm, single-GPU ROCm users get an empty list + # which silently turns unload_all_models() into a no-op. + if is_nvidia() or is_amd(): + for i in range(torch.cuda.device_count()): + devices.append(torch.device("cuda", i)) + elif is_intel_xpu(): + for i in range(torch.xpu.device_count()): + devices.append(torch.device("xpu", i)) + elif is_ascend_npu(): + for i in range(torch.npu.device_count()): + devices.append(torch.device("npu", i)) + elif is_mlu(): + for i in range(torch.mlu.device_count()): + devices.append(torch.device("mlu", i)) + else: + # Fallback for unhandled GPU backends (e.g. DirectML): at least + # report the current device so callers like unload_all_models() + # do not silently no-op. + devices.append(get_torch_device()) + else: + devices.append(get_torch_device()) + if exclude_current: + current = get_torch_device() + if current in devices: + devices.remove(current) + return devices + +def get_gpu_device_options(): + """Return list of device option strings for node widgets. + + Always includes "default" and "cpu". When multiple GPUs are present, + adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels). + """ + options = ["default", "cpu"] + devices = get_all_torch_devices() + if len(devices) > 1: + for i in range(len(devices)): + options.append(f"gpu:{i}") + return options + +def get_gpu_device_options_no_cpu(): + """Variant of get_gpu_device_options that omits "cpu". + + Intended for components like the VAE selector where running on CPU + is impractical and should not be offered as a choice. + """ + return [o for o in get_gpu_device_options() if o != "cpu"] + +def resolve_gpu_device_option(option: str): + """Resolve a device option string to a torch.device. + + Returns None for "default" (let the caller use its normal default). + Returns torch.device("cpu") for "cpu". + For "gpu:N", returns the Nth torch device. Returns None if the + index is out of range, the option string is malformed, or + unrecognized (callers are expected to log their own context-rich + message before falling back to the default device). + """ + if option is None or option == "default": + return None + if option == "cpu": + return torch.device("cpu") + if option.startswith("gpu:"): + try: + idx = int(option[4:]) + except ValueError: + return None + devices = get_all_torch_devices() + if 0 <= idx < len(devices): + return devices[idx] + return None + +@contextmanager +def cuda_device_context(device): + """Context manager that sets torch.cuda.current_device to match *device*. + + Used when running operations on a non-default CUDA device so that custom + CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct + device index. The previous device is restored on exit. + + No-op when *device* is not CUDA, has no explicit index, or already matches + the current device. + """ + prev = None + if device.type == "cuda" and device.index is not None: + prev = torch.cuda.current_device() + if prev != device.index: + torch.cuda.set_device(device) + else: + prev = None + try: + yield + finally: + if prev is not None: + torch.cuda.set_device(prev) + def get_total_memory(dev=None, torch_total_too=False): global directml_enabled if dev is None: @@ -492,9 +599,13 @@ try: logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) except: logging.warning("Could not pick default device.") +try: + for device in get_all_torch_devices(exclude_current=True): + logging.info("Device: {}".format(get_torch_device_name(device))) +except: + pass - -current_loaded_models = [] +current_loaded_models: list[LoadedModel] = [] DIRTY_MMAPS = set() @@ -554,7 +665,7 @@ def ensure_pin_registerable(size, evict_active=False): return shortfall <= REGISTERABLE_PIN_HYSTERESIS class LoadedModel: - def __init__(self, model): + def __init__(self, model: ModelPatcher): self._set_model(model) self.device = model.load_device self.real_model = None @@ -562,7 +673,7 @@ class LoadedModel: self.model_finalizer = None self._patcher_finalizer = None - def _set_model(self, model): + def _set_model(self, model: ModelPatcher): self._model = weakref.ref(model) if model.parent is not None: self._parent_model = weakref.ref(model.parent) @@ -573,6 +684,7 @@ class LoadedModel: model = self._parent_model() if model is not None: self._set_model(model) + self.device = model.load_device @property def model(self): @@ -1848,7 +1960,34 @@ def soft_empty_cache(force=False): torch.cuda.ipc_collect() def unload_all_models(): - free_memory(1e30, get_torch_device()) + for device in get_all_torch_devices(): + free_memory(1e30, device) + +def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False): + 'Unload only model and its clones - primarily for multigpu cloning purposes.' + initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy() + additional_models = [] + if unload_additional_models: + additional_models = model.get_nested_additional_models() + keep_loaded = [] + for loaded_model in initial_keep_loaded: + if loaded_model.model is not None: + if model.clone_base_uuid == loaded_model.model.clone_base_uuid: + continue + # check additional models if they are a match + skip = False + for add_model in additional_models: + if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid: + skip = True + break + if skip: + continue + keep_loaded.append(loaded_model) + if not all_devices: + free_memory(1e30, get_torch_device(), keep_loaded) + else: + for device in get_all_torch_devices(): + free_memory(1e30, device, keep_loaded) def debug_memory_summary(): if is_amd() or is_nvidia(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b44b99e4a..00a15fa63 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -78,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_ def create_model_options_clone(orig_model_options: dict): return comfy.patcher_extension.copy_nested_dicts(orig_model_options) -def create_hook_patches_clone(orig_hook_patches): +def create_hook_patches_clone(orig_hook_patches, copy_tuples=False): new_hook_patches = {} for hook_ref in orig_hook_patches: new_hook_patches[hook_ref] = {} for k in orig_hook_patches[hook_ref]: new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] + if copy_tuples: + for i in range(len(new_hook_patches[hook_ref][k])): + new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i]) return new_hook_patches def wipe_lowvram_weight(m): @@ -329,7 +332,10 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed - self.cached_patcher_init: tuple[Callable, tuple] | None = None + self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None + self.is_multigpu_base_clone = False + self.clone_base_uuid = uuid.uuid4() + if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -366,7 +372,8 @@ class ModelPatcher: #than pays for CFG. So return everything both torch and Aimdo could give us aimdo_mem = 0 if comfy.memory_management.aimdo_enabled: - aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device) return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): @@ -380,6 +387,8 @@ class ModelPatcher: if self.cached_patcher_init is None: raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] model_override = temp_model_patcher.get_clone_model_override() if model_override is None: model_override = self.get_clone_model_override() @@ -438,19 +447,113 @@ class ModelPatcher: n.hook_mode = self.hook_mode n.cached_patcher_init = self.cached_patcher_init + n.is_multigpu_base_clone = self.is_multigpu_base_clone + n.clone_base_uuid = self.clone_base_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n + def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None): + logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.") + if self.cached_patcher_init is None: + raise RuntimeError( + f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: " + "the loader that produced this model does not support multigpu " + "(cached_patcher_init is not initialized). Use a core loader " + "(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), " + "or have the custom loader register a cached_patcher_init factory." + ) + comfy.model_management.unload_model_and_clones(self) + # Produce a freshly-loaded patcher from the loader factory so the multigpu + # clone owns its own untainted model weights (rather than relying on + # copy.deepcopy of an already-patched/already-loaded module). + temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1]) + if len(self.cached_patcher_init) > 2: + temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]] + # Override clone()'s normal "share self.model + share backup containers" with + # the pristine model from temp_model_patcher plus empty backup containers -- + # the fresh model has no patches applied, so any deepcopy of self's stale + # backup/object_patches_backup/pinned would just propagate dead state that + # no longer corresponds to anything in n.model. + model_override = (temp_model_patcher.model, ({}, {}, {}, set())) + n = self.clone(model_override=model_override) + # clone() copies hook_backup by reference from self; reset since model is pristine. + n.hook_backup = {} + # set load device, if present + if new_load_device is not None: + n.load_device = new_load_device + # Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins) + # has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's + # __init__ only registered its own (default) load_device. + if hasattr(n, "register_load_device"): + n.register_load_device(n.load_device) + # multigpu clone should not have multigpu additional_models entry + n.remove_additional_models("multigpu") + # multigpu_clone all stored additional_models; make sure circular references are properly handled + if models_cache is None: + models_cache = {} + for key, model_list in n.additional_models.items(): + for i in range(len(model_list)): + add_model = n.additional_models[key][i] + if add_model.clone_base_uuid not in models_cache: + models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache) + n.additional_models[key][i] = models_cache[add_model.clone_base_uuid] + for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU): + callback(self, n) + return n + + def match_multigpu_clones(self): + multigpu_models = self.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + new_multigpu_models = [] + for mm in multigpu_models: + # clone main model, but bring over relevant props from existing multigpu clone + n = self.clone() + n.load_device = mm.load_device + n.backup = mm.backup + n.object_patches_backup = mm.object_patches_backup + n.hook_backup = mm.hook_backup + n.model = mm.model + n.is_multigpu_base_clone = mm.is_multigpu_base_clone + n.remove_additional_models("multigpu") + orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models) + n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models) + # figure out which additional models are not present in multigpu clone + models_cache = {} + for mm_add_model in mm.get_additional_models(): + models_cache[mm_add_model.clone_base_uuid] = mm_add_model + remove_models_uuids = set(list(models_cache.keys())) + for key, model_list in orig_additional_models.items(): + for orig_add_model in model_list: + if orig_add_model.clone_base_uuid not in models_cache: + models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache) + existing_list = n.get_additional_models_with_key(key) + existing_list.append(models_cache[orig_add_model.clone_base_uuid]) + n.set_additional_models(key, existing_list) + if orig_add_model.clone_base_uuid in remove_models_uuids: + remove_models_uuids.remove(orig_add_model.clone_base_uuid) + # remove duplicate additional models + for key, model_list in n.additional_models.items(): + new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids] + n.set_additional_models(key, new_model_list) + for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES): + callback(self, n) + new_multigpu_models.append(n) + self.set_additional_models("multigpu", new_multigpu_models) + def is_clone(self, other): if hasattr(other, 'model') and self.model is other.model: return True return False - def clone_has_same_weights(self, clone: 'ModelPatcher'): - if not self.is_clone(clone): - return False + def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False): + if allow_multigpu: + if self.clone_base_uuid != clone.clone_base_uuid: + return False + else: + if not self.is_clone(clone): + return False if self.current_hooks != clone.current_hooks: return False @@ -1232,7 +1335,7 @@ class ModelPatcher: return self.additional_models.get(key, []) def get_additional_models(self): - all_models = [] + all_models: list[ModelPatcher] = [] for models in self.additional_models.values(): all_models.extend(models) return all_models @@ -1286,9 +1389,18 @@ class ModelPatcher: for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): callback(self) - def prepare_state(self, timestep): + def prepare_state(self, timestep, model_options): + ignore_multigpu = model_options.get("ignore_multigpu", False) for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): - callback(self, timestep) + callback(self, timestep, model_options) + if not ignore_multigpu and "multigpu_clones" in model_options: + model_options["ignore_multigpu"] = True + try: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p.prepare_state(timestep, model_options) + finally: + model_options.pop("ignore_multigpu", None) def restore_hook_patches(self): if self.hook_patches_backup is not None: @@ -1301,12 +1413,18 @@ class ModelPatcher: def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): curr_t = t[0] reset_current_hooks = False + multigpu_kf_changed_cache = None transformer_options = model_options.get("transformer_options", {}) for hook in hook_group.hooks: changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # this will cause the weights to be recalculated when sampling if changed: + # cache changed for multigpu usage + if "multigpu_clones" in model_options: + if multigpu_kf_changed_cache is None: + multigpu_kf_changed_cache = [] + multigpu_kf_changed_cache.append(hook) # reset current_hooks if contains hook that changed if self.current_hooks is not None: for current_hook in self.current_hooks.hooks: @@ -1318,6 +1436,28 @@ class ModelPatcher: self.cached_hook_patches.pop(cached_group) if reset_current_hooks: self.patch_hooks(None) + if "multigpu_clones" in model_options: + for p in model_options["multigpu_clones"].values(): + p: ModelPatcher + p._handle_changed_hook_keyframes(multigpu_kf_changed_cache) + + def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]): + 'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.' + if kf_changed_cache is None: + return + reset_current_hooks = False + # reset current_hooks if contains hook that changed + for hook in kf_changed_cache: + if self.current_hooks is not None: + for current_hook in self.current_hooks.hooks: + if current_hook == hook: + reset_current_hooks = True + break + for cached_group in list(self.cached_hook_patches.keys()): + if cached_group.contains(hook): + self.cached_hook_patches.pop(cached_group) + if reset_current_hooks: + self.patch_hooks(None) def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, registered: comfy.hooks.HookGroup = None): @@ -1566,16 +1706,27 @@ class ModelPatcherDynamic(ModelPatcher): self.model.dynamic_vbars = {} if not hasattr(self.model, "dynamic_pins"): self.model.dynamic_pins = {} - if self.load_device not in self.model.dynamic_pins: - self.model.dynamic_pins[self.load_device] = { + self.register_load_device(self.load_device) + self.non_dynamic_delegate_model = None + assert load_device is not None + + def register_load_device(self, device): + """Ensure dynamic_pins has an entry for *device*. + + Called from __init__ and also from any code that retargets an + already-constructed patcher to a new load_device (e.g. the + Select{Model,CLIP,VAE}Device selector nodes); without this entry + partially_unload_ram() raises KeyError when it tries to read the + per-device pin state. + """ + if device not in self.model.dynamic_pins: + self.model.dynamic_pins[device] = { "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), "hostbufs_initialized": False, "failed": False, "active": False, } - self.non_dynamic_delegate_model = None - assert load_device is not None def is_dynamic(self): return True diff --git a/comfy/multigpu.py b/comfy/multigpu.py new file mode 100644 index 000000000..e7f5b3d6f --- /dev/null +++ b/comfy/multigpu.py @@ -0,0 +1,248 @@ +from __future__ import annotations +import queue +import threading +import torch +import logging + +from collections import namedtuple +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher +import comfy.utils +import comfy.patcher_extension +import comfy.model_management + + +class MultiGPUThreadPool: + """Persistent thread pool for multi-GPU work distribution. + + Maintains one worker thread per extra GPU device. Each thread calls + torch.cuda.set_device() once at startup so that compiled kernel caches + (inductor/triton) stay warm across diffusion steps. + """ + + def __init__(self, devices: list[torch.device]): + self._workers: list[threading.Thread] = [] + self._work_queues: dict[torch.device, queue.Queue] = {} + self._result_queues: dict[torch.device, queue.Queue] = {} + + for device in devices: + wq = queue.Queue() + rq = queue.Queue() + self._work_queues[device] = wq + self._result_queues[device] = rq + t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True) + t.start() + self._workers.append(t) + + def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): + try: + torch.cuda.set_device(device) + except Exception as e: + logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") + while True: + item = work_q.get() + if item is None: + return + result_q.put((None, e)) + return + while True: + item = work_q.get() + if item is None: + break + fn, args, kwargs = item + try: + result = fn(*args, **kwargs) + result_q.put((result, None)) + except Exception as e: + result_q.put((None, e)) + + def submit(self, device: torch.device, fn, *args, **kwargs): + self._work_queues[device].put((fn, args, kwargs)) + + def get_result(self, device: torch.device): + return self._result_queues[device].get() + + @property + def devices(self) -> list[torch.device]: + return list(self._work_queues.keys()) + + def shutdown(self): + for wq in self._work_queues.values(): + wq.put(None) # sentinel + for t in self._workers: + t.join(timeout=5.0) + + +class GPUOptions: + def __init__(self, device_index: int, relative_speed: float): + self.device_index = device_index + self.relative_speed = relative_speed + + def clone(self): + return GPUOptions(self.device_index, self.relative_speed) + + def create_dict(self): + return { + "relative_speed": self.relative_speed + } + +class GPUOptionsGroup: + def __init__(self): + self.options: dict[int, GPUOptions] = {} + + def add(self, info: GPUOptions): + self.options[info.device_index] = info + + def clone(self): + c = GPUOptionsGroup() + for opt in self.options.values(): + c.add(opt) + return c + + def register(self, model: ModelPatcher): + opts_dict = {} + # get devices that are valid for this model + devices: list[torch.device] = [model.load_device] + for extra_model in model.get_additional_models_with_key("multigpu"): + extra_model: ModelPatcher + devices.append(extra_model.load_device) + # create dictionary with actual device mapped to its GPUOptions + device_opts_list: list[GPUOptions] = [] + for device in devices: + device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0)) + opts_dict[device] = device_opts.create_dict() + device_opts_list.append(device_opts) + # make relative_speed relative to 1.0 + min_speed = min([x.relative_speed for x in device_opts_list]) + for value in opts_dict.values(): + value['relative_speed'] /= min_speed + model.model_options['multigpu_options'] = opts_dict + + +def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False): + 'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.' + model = model.clone() + # check if multigpu is already prepared - get the load devices from them if possible to exclude + skip_devices = set() + multigpu_models = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) > 0: + for mm in multigpu_models: + skip_devices.add(mm.load_device) + skip_devices = list(skip_devices) + + # Exclude the primary model's actual device, not the global current device: + # after SelectModelDevice(gpu:N) the primary may not live on the process's + # current CUDA device, and excluding the wrong device picks bad extras. + all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False) + full_extra_devices = [d for d in all_devices if d != model.load_device] + limit_extra_devices = full_extra_devices[:max_gpus-1] + extra_devices = limit_extra_devices.copy() + # exclude skipped devices + for skip in skip_devices: + if skip in extra_devices: + extra_devices.remove(skip) + # create new deepclones + if len(extra_devices) > 0: + for device in extra_devices: + device_patcher = None + if reuse_loaded: + # Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice + # patcher on the same device shares clone_base_uuid but has + # is_multigpu_base_clone=False, which would later be filtered out by + # prepare_model_patcher_multigpu_clones() and silently shrink the + # work split back to one GPU. + loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models() + for lm in loaded_models: + if lm.model is None: + continue + if lm.load_device != device: + continue + if lm.clone_base_uuid != model.clone_base_uuid: + continue + if not getattr(lm, "is_multigpu_base_clone", False): + continue + device_patcher = lm.clone() + logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}") + break + if device_patcher is None: + device_patcher = model.deepclone_multigpu(new_load_device=device) + # Always flag the clone; whether reused or freshly deepcloned, it must + # advertise itself as a MultiGPU base clone so the cond scheduler picks + # it up in prepare_model_patcher_multigpu_clones(). + device_patcher.is_multigpu_base_clone = True + multigpu_models = model.get_additional_models_with_key("multigpu") + multigpu_models.append(device_patcher) + model.set_additional_models("multigpu", multigpu_models) + model.match_multigpu_clones() + if gpu_options is None: + gpu_options = GPUOptionsGroup() + gpu_options.register(model) + else: + logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.") + # only keep model clones that don't go 'past' the intended max_gpu count; + # this prunes any inherited multigpu clones whose load_device is no longer allowed + # when max_gpus is lowered between runs. + allowed_devices = set(limit_extra_devices) + allowed_devices.add(model.load_device) + multigpu_models = model.get_additional_models_with_key("multigpu") + new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices] + if len(new_multigpu_models) != len(multigpu_models): + model.set_additional_models("multigpu", new_multigpu_models) + model.match_multigpu_clones() + return model + + +LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) +def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): + 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' + opts_dict = model_options['multigpu_options'] + devices = list(model_options['multigpu_clones'].keys()) + speed_per_device = [] + work_per_device = [] + # get sum of each device's relative_speed + total_speed = 0.0 + for opts in opts_dict.values(): + total_speed += opts['relative_speed'] + # get relative work for each device; + # obtained by w = (W*r)/R + for device in devices: + relative_speed = opts_dict[device]['relative_speed'] + relative_work = (total_work*relative_speed) / total_speed + speed_per_device.append(relative_speed) + work_per_device.append(relative_work) + # relative work must be expressed in whole numbers, but likely is a decimal; + # perform rounding while maintaining total sum equal to total work (sum of relative works) + work_per_device = round_preserved(work_per_device) + dict_work_per_device = {} + for device, relative_work in zip(devices, work_per_device): + dict_work_per_device[device] = relative_work + if not return_idle_time: + return LoadBalance(dict_work_per_device, None) + # divide relative work by relative speed to get estimated completion time of said work by each device; + # time here is relative and does not correspond to real-world units + completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)] + # calculate relative time spent by the devices waiting on each other after their work is completed + idle_time = abs(min(completion_time) - max(completion_time)) + # if need to compare work idle time, need to normalize to a common total work + if work_normalized: + idle_time *= (work_normalized/total_work) + + return LoadBalance(dict_work_per_device, idle_time) + +def round_preserved(values: list[float]): + 'Round all values in a list, preserving the combined sum of values.' + # get floor of values; casting to int does it too + floored = [int(x) for x in values] + total_floored = sum(floored) + # get remainder to distribute + remainder = round(sum(values)) - total_floored + # pair values with fractional portions + fractional = [(i, x-floored[i]) for i, x in enumerate(values)] + # sort by fractional part in descending order + fractional.sort(key=lambda x: x[1], reverse=True) + # distribute the remainder + for i in range(remainder): + index = fractional[i][0] + floored[index] += 1 + return floored diff --git a/comfy/ops.py b/comfy/ops.py index 9bcd6c900..56445be8d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -18,6 +18,7 @@ import torch import logging +import contextlib import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float @@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +# Quantized-weight module helpers + +def _quantized_apply(module, fn, recurse=True): + """Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights.""" + if recurse: + for child in module.children(): + child._apply(fn) + for key, param in module._parameters.items(): + if param is None: + continue + p = fn(param) + if (not torch.is_inference_mode_enabled()) and p.is_inference(): + p = p.clone() + module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) + for key, buf in module._buffers.items(): + if buf is not None: + module._buffers[key] = fn(buf) + return module + + +def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs, load_extra_params=False): + """Shared _load_from_state_dict body for quantized-weight modules. + + Pops weight (+ scales, +/- extras), populates module.weight as a Parameter + or Parameter-wrapped QuantizedTensor, then calls super_load and strips + consumed keys from missing_keys. Reads compute_dtype from factory_kwargs + and disabled formats from module._disabled_formats. + """ + device = module.factory_kwargs["device"] + compute_dtype = module.factory_kwargs["dtype"] + disabled_formats = module._disabled_formats + layer_name = prefix.rstrip('.') + + weight = state_dict.pop(f"{prefix}weight", None) + if weight is None: + logging.warning(f"Missing weight for layer {layer_name}") + module.weight = None + return + manually_loaded_keys = [f"{prefix}weight"] + + def pop_scale(name, dtype=None): + key = f"{prefix}{name}" + v = state_dict.pop(key, None) + if v is not None: + v = v.to(device=device) + if dtype is not None: + v = v.view(dtype=dtype) + manually_loaded_keys.append(key) + return v + + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: + module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False) + else: + module.quant_format = layer_conf.get("format", None) + module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) + if not module._full_precision_mm: + module._full_precision_mm = module._full_precision_mm_config + if module.quant_format in disabled_formats: + module._full_precision_mm = True + if module.quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + qconfig = QUANT_ALGOS[module.quant_format] + module.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(module.layout_type) + + # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8. + if module.quant_format in ("float8_e4m3fn", "float8_e5m2"): + scales = {"scale": pop_scale("weight_scale")} + elif module.quant_format == "mxfp8": + bs = pop_scale("weight_scale", torch.float8_e8m0fnu) + if bs is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + scales = {"scale": bs} + elif module.quant_format == "nvfp4": + ts = pop_scale("weight_scale_2") + bs = pop_scale("weight_scale", torch.float8_e4m3fn) + if ts is None or bs is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + scales = {"scale": ts, "block_scale": bs} + else: + raise ValueError(f"Unsupported quantization format: {module.quant_format}") + + params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape) + module.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params), + requires_grad=False, + ) + + if load_extra_params: + for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + +def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()): + """Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON; + extra_quant_params names attributes written as additional top-level keys.""" + if not hasattr(module, 'weight'): + logging.warning(f"Warning: state dict on uninitialized op {prefix}") + return sd + bias = getattr(module, 'bias', None) + if bias is not None: + sd[f"{prefix}bias"] = bias + if module.weight is None: + return sd + if isinstance(module.weight, QuantizedTensor): + sd.update(module.weight.state_dict(f"{prefix}weight")) + quant_conf = {"format": module.quant_format} + if getattr(module, '_full_precision_mm_config', False): + quant_conf["full_precision_matrix_mult"] = True + if extra_quant_conf: + quant_conf.update(extra_quant_conf) + sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8) + for name in extra_quant_params: + value = getattr(module, name, None) + if value is not None: + sd[f"{prefix}{name}"] = value + else: + sd[f"{prefix}weight"] = module.weight + return sd + def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): @@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: + _disabled_formats = disabled + + def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): super().__init__() self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features + self._orig_shape = (out_features, in_features) if bias: self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) else: @@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def reset_parameters(self): return None - def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): - key = f"{prefix}{param_name}" - value = state_dict.pop(key, None) - if value is not None: - value = value.to(device=device) - if dtype is not None: - value = value.view(dtype=dtype) - manually_loaded_keys.append(key) - return value - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): - - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - logging.warning(f"Missing weight for layer {layer_name}") - self.weight = None - return - - manually_loaded_keys = [weight_key] - - layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) - if layer_conf is not None: - layer_conf = json.loads(layer_conf.numpy().tobytes()) - - if layer_conf is None: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - self.quant_format = layer_conf.get("format", None) - self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) - if not self._full_precision_mm: - self._full_precision_mm = self._full_precision_mm_config - - if self.quant_format in MixedPrecisionOps._disabled: - self._full_precision_mm = True - - if self.quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") - - qconfig = QUANT_ALGOS[self.quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] - layout_cls = get_layout_class(self.layout_type) - - # Load format-specific parameters - if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: - # FP8: single tensor scale - scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - - params = layout_cls.Params( - scale=scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "mxfp8": - # MXFP8: E8M0 block scales stored as uint8 in safetensors - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.uint8) - - if block_scale is None: - raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") - - block_scale = block_scale.view(torch.float8_e8m0fnu) - - params = layout_cls.Params( - scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "nvfp4": - # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) - tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.float8_e4m3fn) - - if tensor_scale is None or block_scale is None: - raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") - - params = layout_cls.Params( - scale=tensor_scale, - block_scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - else: - raise ValueError(f"Unsupported quantization format: {self.quant_format}") - - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), - requires_grad=False - ) - - for param_name in qconfig["parameters"]: - if param_name in {"weight_scale", "weight_scale_2"}: - continue # Already handled above - - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) - - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight'): - logging.warning("Warning: state dict on uninitialized op {}".format(prefix)) - return sd - - if self.bias is not None: - sd["{}bias".format(prefix)] = self.bias - - if self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - if self._full_precision_mm_config: - quant_conf["full_precision_matrix_mult"] = True - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - - input_scale = getattr(self, 'input_scale', None) - if input_scale is not None: - sd["{}input_scale".format(prefix)] = input_scale - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",)) def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight, requires_grad=False) def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working - if recurse: - for module in self.children(): - module._apply(fn) + return _quantized_apply(self, fn, recurse) - for key, param in self._parameters.items(): - if param is None: - continue - p = fn(param) - if (not torch.is_inference_mode_enabled()) and p.is_inference(): - p = p.clone() - self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) - for key, buf in self._buffers.items(): - if buf is not None: - self._buffers[key] = fn(buf) - return self + class MoEExperts(torch.nn.Module, CastWeightBiasOp): + """Container for E quantized expert weights, indexed via expert_weight(i). + + The bank lives on self.weight as a single 3D tensor — either a + compute_dtype Parameter or a Parameter wrapping a QuantizedTensor + with leading expert dim. + + State-dict layout matches mixed_precision_ops.Linear with a leading + expert dim: + {prefix}.weight quant data (storage_t), leading dim = E + {prefix}.weight_scale block / per-tensor scale + {prefix}.weight_scale_2 [E] or scalar NVFP4 only + {prefix}.bias [E, out_features] optional, compute_dtype + {prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}} + + Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in]. + """ + + _disabled_formats = disabled + + def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): + super().__init__() + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self._orig_shape = (num_experts, out_features, in_features) + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + if bias: + self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + # Populated by _load_from_state_dict: + self.weight = None + self.quant_format = None + self.layout_type = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False + self._resident_bank = None + + def reset_parameters(self): + return None + + def _apply(self, fn, recurse=True): + return _quantized_apply(self, fn, recurse) + + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False) + + def expert_weight(self, i: int): + """Expert i's weight (Tensor or per-expert QuantizedTensor view).""" + if isinstance(self.weight, QuantizedTensor): + return self._expert_qt_from(self.weight, i) + return self.weight[i] + + @contextlib.contextmanager + def bank_resident(self, input): + """Cast the whole bank once; expert_linear inside reuses the cast. + Not re-entrant — do not nest calls on the same instance. + """ + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + self._resident_bank = (weight, bias) + try: + yield self + finally: + self._resident_bank = None + uncast_bias_weight(self, weight, bias, offload_stream) + + def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor: + """Linear against expert i's weight (with optional bias).""" + resident = getattr(self, "_resident_bank", None) + if resident is not None: + weight, bias = resident + return self._expert_linear_impl(input, weight, bias, i) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + try: + return self._expert_linear_impl(input, weight, bias, i) + finally: + uncast_bias_weight(self, weight, bias, offload_stream) + + def _expert_linear_impl(self, input, weight, bias, i): + if isinstance(weight, QuantizedTensor): + qw = self._expert_qt_from(weight, i) + else: + qw = weight[i] + b = cast_to_input(bias[i], input, copy=False) if bias is not None else None + + if isinstance(qw, QuantizedTensor): + use_fast = ( + not self._full_precision_mm + and qw.layout_cls.supports_fast_matmul() + and input.dim() == 2 + ) + if use_fast: + qin = QuantizedTensor.from_float(input, self.layout_type) + return torch.nn.functional.linear(qin, qw, b) + out = input @ qw.dequantize().t() + return out + b if b is not None else out + return torch.nn.functional.linear(input, qw, b) + + def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor: + """Build a per-expert QuantizedTensor by indexing into a resident bank.""" + params = weight._params + kwargs = { + "scale": params.scale[i] if params.scale.dim() else params.scale, + "orig_dtype": params.orig_dtype, + "orig_shape": (self.out_features, self.in_features), + } + if hasattr(params, "block_scale"): # NVFP4 + kwargs["block_scale"] = params.block_scale[i] + return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs)) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts}) class Embedding(manual_cast.Embedding): - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): weight_key = f"{prefix}weight" layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) if layer_conf is not None: @@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # Only fp8 makes sense for embeddings (per-row dequant via index select). # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. - quant_format = layer_conf.get("format", None) if layer_conf is not None else None - if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + quant_format = layer_conf.get("format") if layer_conf is not None else None + manually_loaded_keys = [] + + if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict: self.quant_format = quant_format qconfig = QUANT_ALGOS[quant_format] self.layout_type = qconfig["comfy_tensor_layout"] layout_cls = get_layout_class(self.layout_type) weight = state_dict.pop(weight_key) - manually_loaded_keys = [weight_key] + manually_loaded_keys.append(weight_key) scale_key = f"{prefix}weight_scale" scale = state_dict.pop(scale_key, None) @@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter( QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), requires_grad=False) + elif layer_conf is not None: + # Unsupported format — restore the marker so it round-trips; fall through to default load. + state_dict[f"{prefix}comfy_quant"] = torch.tensor( + list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - for k in manually_loaded_keys: - if k in missing_keys: - missing_keys.remove(k) - else: - if layer_conf is not None: - state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight') or self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix) def forward_comfy_cast_weights(self, input, out_dtype=None): weight = self.weight diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py index 5ee4d5ee5..189ee84ca 100644 --- a/comfy/patcher_extension.py +++ b/comfy/patcher_extension.py @@ -1,8 +1,9 @@ -from __future__ import annotations from typing import Callable class CallbacksMP: ON_CLONE = "on_clone" + ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu" + ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones" ON_LOAD = "on_load_after" ON_DETACH = "on_detach_after" ON_CLEANUP = "on_cleanup" diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 3782fd2d5..bdce2f2d8 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,16 +1,18 @@ from __future__ import annotations +import torch import uuid import math import collections import comfy.model_management import comfy.conds +import comfy.model_patcher import comfy.utils import comfy.hooks import comfy.patcher_extension from typing import TYPE_CHECKING if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel + from comfy.model_patcher import ModelPatcher from comfy.controlnet import ControlBase def prepare_mask(noise_mask, shape, device): @@ -119,6 +121,47 @@ def cleanup_additional_models(models): if hasattr(m, 'cleanup'): m.cleanup() +def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]): + '''If multigpu acceleration required, creates deepclones of ControlNets and GLIGEN per device.''' + multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu") + if len(multigpu_models) == 0: + return + extra_devices = [x.load_device for x in multigpu_models] + # handle controlnets + controlnets: set[ControlBase] = set() + for k in conds: + for kk in conds[k]: + if 'control' in kk: + controlnets.add(kk['control']) + if len(controlnets) > 0: + # first, unload all controlnet clones + for cnet in list(controlnets): + cnet_models = cnet.get_models() + for cm in cnet_models: + comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True) + + # next, make sure each controlnet has a deepclone for all relevant devices + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + for device in extra_devices: + if device not in curr_cnet.multigpu_clones: + curr_cnet.deepclone_multigpu(device, autoregister=True) + curr_cnet = curr_cnet.previous_controlnet + # since all device clones are now present, recreate the linked list for cloned cnets per device + for cnet in controlnets: + curr_cnet = cnet + while curr_cnet is not None: + prev_cnet = curr_cnet.previous_controlnet + for device in extra_devices: + device_cnet = curr_cnet.get_instance_for_device(device) + prev_device_cnet = None + if prev_cnet is not None: + prev_device_cnet = prev_cnet.get_instance_for_device(device) + device_cnet.set_previous_controlnet(prev_device_cnet) + curr_cnet = prev_cnet + # potentially handle gligen - since not widely used, ignored for now + def estimate_memory(model, noise_shape, conds): cond_shapes = collections.defaultdict(list) cond_shapes_min = {} @@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload) def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): - real_model: BaseModel = None + model.match_multigpu_clones() + preprocess_multigpu_conds(conds, model, model_options) models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? @@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non memory_required += inference_memory minimum_memory_required += inference_memory comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) - real_model = model.model + real_model: BaseModel = model.model return real_model, conds, models @@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], copy_dict1=False) return to_load_options + +def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict): + ''' + In case multigpu acceleration is enabled, prep ModelPatchers for each device. + ''' + multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone] + if len(multigpu_patchers) > 0: + multigpu_dict: dict[torch.device, ModelPatcher] = {} + multigpu_dict[model_patcher.load_device] = model_patcher + for x in multigpu_patchers: + x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True) + x.hook_mode = model_patcher.hook_mode # match main model's hook_mode + multigpu_dict[x.load_device] = x + model_options["multigpu_clones"] = multigpu_dict + return multigpu_patchers diff --git a/comfy/samplers.py b/comfy/samplers.py index c5e36ff05..e31277f7b 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -1,7 +1,9 @@ from __future__ import annotations + +import comfy.model_management from .k_diffusion import sampling as k_diffusion_sampling from .extra_samplers import uni_pc -from typing import TYPE_CHECKING, Callable, NamedTuple +from typing import TYPE_CHECKING, Callable, NamedTuple, Any if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel @@ -16,6 +18,7 @@ import comfy.model_patcher import comfy.patcher_extension import comfy.hooks import comfy.context_windows +import comfy.multigpu import comfy.utils import scipy.stats import numpy @@ -141,7 +144,7 @@ def can_concat_cond(c1, c2): return cond_equal_size(c1.conditioning, c2.conditioning) -def cond_cat(c_list): +def cond_cat(c_list, device=None): temp = {} for x in c_list: for k in x: @@ -153,6 +156,8 @@ def cond_cat(c_list): for k in temp: conds = temp[k] out[k] = conds[0].concat(conds[1:]) + if device is not None and hasattr(out[k], 'to'): + out[k] = out[k].to(device) return out @@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic + # (hooked_to_run accumulation, memory-fit batching, per-chunk output + # aggregation) is duplicated there with per-device scheduling layered on top. + if 'multigpu_clones' in model_options: + return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] out_counts = [] # separate conds by matching hooks @@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens if has_default_conds: finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) - model.current_patcher.prepare_state(timestep) + model.current_patcher.prepare_state(timestep, model_options) # run every hooked_to_run separately for hooks, to_run in hooked_to_run.items(): @@ -344,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens return out_conds +def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks + # accumulation, memory-fit batching, and output aggregation, but adds a + # per-device scheduler, per-device patcher/control lookup, tensor .to(device) + # placement, and MultiGPUThreadPool dispatch around the inner loop. + out_conds = [] + out_counts = [] + # separate conds by matching hooks + hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {} + default_conds = [] + has_default_conds = False + + output_device = x_in.device + + for i in range(len(conds)): + out_conds.append(torch.zeros_like(x_in)) + out_counts.append(torch.ones_like(x_in) * 1e-37) + + cond = conds[i] + default_c = [] + if cond is not None: + for x in cond: + if 'default' in x: + default_c.append(x) + has_default_conds = True + continue + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + if p.hooks is not None: + model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options) + hooked_to_run.setdefault(p.hooks, list()) + hooked_to_run[p.hooks] += [(p, i)] + default_conds.append(default_c) + + if has_default_conds: + finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) + + model.current_patcher.prepare_state(timestep, model_options) + + devices = list(model_options['multigpu_clones'].keys()) + device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} + # Track conds currently scheduled per device; single source of truth for capacity checks. + device_load: dict[torch.device, int] = {d: 0 for d in devices} + + total_conds = sum(len(to_run) for to_run in hooked_to_run.values()) + conds_per_device = max(1, math.ceil(total_conds / len(devices))) + + def next_available_device(start: int) -> tuple[int, torch.device]: + """Return (index, device) for the next device with remaining capacity, starting at `start`. + + Scans at most len(devices) positions, so this always terminates. Raises if no device + has remaining capacity, which would indicate a bug in conds_per_device accounting. + """ + for offset in range(len(devices)): + i = (start + offset) % len(devices) + if device_load[devices[i]] < conds_per_device: + return i, devices[i] + raise RuntimeError( + f"MultiGPU scheduler: all {len(devices)} devices at capacity " + f"({conds_per_device}) but conds remain to schedule" + ) + + # run every hooked_to_run separately + index_device = 0 + for hooks, to_run in hooked_to_run.items(): + while len(to_run) > 0: + index_device, current_device = next_available_device(index_device) + remaining_capacity = conds_per_device - device_load[current_device] + + first = to_run[0] + first_shape = first[0][0].shape + # collect candidate indices that can be concatenated with `first`, up to remaining capacity + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity: + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = comfy.model_management.get_free_memory(current_device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + cond_shapes = collections.defaultdict(list) + for tt in batch_amount: + for k, v in to_run[tt][0].conditioning.items(): + cond_shapes[k].append(v.size()) + if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: + to_batch = batch_amount + break + + conds_to_batch = [to_run.pop(x) for x in to_batch] + device_load[current_device] += len(conds_to_batch) + device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch)) + + if device_load[current_device] >= conds_per_device: + index_device += 1 + + class thread_result(NamedTuple): + output: Any + mult: Any + area: Any + batch_chunks: int + cond_or_uncond: Any + error: Exception = None + + def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): + try: + # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once + # we extend multigpu QA beyond CUDA. Unconditional call crashes on + # XPU/NPU/MPS/CPU/DirectML backends. + torch.cuda.set_device(device) + model_current: BaseModel = model_options["multigpu_clones"][device].model + # run every hooked_to_run separately + with torch.no_grad(): + for hooks, to_batch in batch_tuple: + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + uuids = [] + area = [] + control: ControlBase = None + patches = None + for x in to_batch: + o = x + p = o[0] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + uuids.append(p.uuid) + control = p.control + patches = p.patches + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x).to(device) + c = cond_cat(c, device=device) + timestep_ = torch.cat([timestep.to(device)] * batch_chunks) + + transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks) + if 'transformer_options' in model_options: + transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options, + model_options['transformer_options'], + copy_dict1=False) + + if patches is not None: + transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts( + transformer_options.get("patches", {}), + patches + ) + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["uuids"] = uuids[:] + transformer_options["sigmas"] = timestep.to(device) + transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device) + transformer_options["multigpu_thread_device"] = device + + cast_transformer_options(transformer_options, device=device) + c['transformer_options'] = transformer_options + + if control is not None: + device_control = control.get_instance_for_device(device) + c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options) + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) + else: + output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + # TODO: non-NVIDIA support -- the `.to(output_device)` copies + # above are async on CUDA, so the main thread's aggregation + # could race with in-flight transfers. CUDA-only QA has not + # surfaced this in practice, but before extending multigpu + # beyond NVIDIA add a `torch.cuda.synchronize(output_device)` + # here (guarded by `output_device.type == "cuda"`). + results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) + except Exception as e: + results.append(thread_result(None, None, None, None, None, error=e)) + raise + + + def _handle_batch_pooled(device, batch_tuple): + worker_results = [] + _handle_batch(device, batch_tuple, worker_results) + return worker_results + + results: list[thread_result] = [] + thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool") + + # Submit all GPU work to pool threads + pool_devices = [] + for device, batch_tuple in device_batched_hooked_to_run.items(): + if thread_pool is not None: + thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple) + pool_devices.append(device) + else: + # Fallback: no pool, run everything on main thread + _handle_batch(device, batch_tuple, results) + + # Collect results from pool workers + for device in pool_devices: + worker_results, error = thread_pool.get_result(device) + if error is not None: + raise error + results.extend(worker_results) + + for output, mult, area, batch_chunks, cond_or_uncond, error in results: + if error is not None: + raise error + for o in range(batch_chunks): + cond_index = cond_or_uncond[o] + a = area[o] + if a is None: + out_conds[cond_index] += output[o] * mult[o] + out_counts[cond_index] += mult[o] + else: + out_c = out_conds[cond_index] + out_cts = out_counts[cond_index] + dims = len(a) // 2 + for i in range(dims): + out_c = out_c.narrow(i + 2, a[i + dims], a[i]) + out_cts = out_cts.narrow(i + 2, a[i + dims], a[i]) + out_c += output[o] * mult[o] + out_cts += mult[o] + + for i in range(len(out_conds)): + out_conds[i] /= out_counts[i] + + return out_conds + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) @@ -642,12 +885,21 @@ def calculate_start_end_timesteps(model, conds): def pre_run_control(model, conds): s = model.model_sampling + # Per-device model lookup so multigpu control clones get the matching + # diffusion_model (e.g. QwenFunControlNet stashes it into extra_args). + device_models: dict = {} + patcher = getattr(model, "current_patcher", None) + if patcher is not None: + for p in patcher.get_additional_models_with_key("multigpu"): + device_models[p.load_device] = p.model for t in range(len(conds)): x = conds[t] percent_to_timestep_function = lambda a: s.percent_to_sigma(a) if 'control' in x: x['control'].pre_run(model, percent_to_timestep_function) + for device, device_cnet in x['control'].multigpu_clones.items(): + device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function) def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): cond_cnets = [] @@ -890,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): to_load_options = model_options.get("to_load_options", None) if to_load_options is None: return + cast_transformer_options(to_load_options, device, dtype) +def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None): casts = [] if device is not None: casts.append(device) @@ -899,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # if nothing to apply, do nothing if len(casts) == 0: return - # try to call .to on patches - if "patches" in to_load_options: - patches = to_load_options["patches"] + if "patches" in transformer_options: + patches = transformer_options["patches"] for name in patches: patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): for cast in casts: patch_list[i] = patch_list[i].to(cast) - if "patches_replace" in to_load_options: - patches = to_load_options["patches_replace"] + if "patches_replace" in transformer_options: + patches = transformer_options["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: @@ -920,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): # try to call .to on any wrappers/callbacks wrappers_and_callbacks = ["wrappers", "callbacks"] for wc_name in wrappers_and_callbacks: - if wc_name in to_load_options: - wc: dict[str, list] = to_load_options[wc_name] + if wc_name in transformer_options: + wc: dict[str, list] = transformer_options[wc_name] for wc_dict in wc.values(): for wc_list in wc_dict.values(): for i in range(len(wc_list)): @@ -929,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None): for cast in casts: wc_list[i] = wc_list[i].to(cast) - class CFGGuider: def __init__(self, model_patcher: ModelPatcher): self.model_patcher = model_patcher @@ -984,16 +1236,32 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - noise = noise.to(device=device, dtype=torch.float32) - latent_image = latent_image.to(device=device, dtype=torch.float32) - sigmas = sigmas.to(device) - cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options) - try: - self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) - finally: - self.model_patcher.cleanup() + # Create persistent thread pool for all GPU devices (main + extras) + if multigpu_patchers: + extra_devices = [p.load_device for p in multigpu_patchers] + all_devices = [device] + extra_devices + self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices) + + with comfy.model_management.cuda_device_context(device): + try: + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) + sigmas = sigmas.to(device) + cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) + + self.model_patcher.pre_run() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.pre_run() + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) + finally: + thread_pool = self.model_options.pop("multigpu_thread_pool", None) + if thread_pool is not None: + thread_pool.shutdown() + self.model_patcher.cleanup() + for multigpu_patcher in multigpu_patchers: + multigpu_patcher.cleanup() comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) del self.inner_model diff --git a/comfy/sd.py b/comfy/sd.py index 7bd07ed3a..30b877b85 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,3 @@ -from __future__ import annotations import json import torch from enum import Enum @@ -50,6 +49,7 @@ import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 +import comfy.text_encoders.pixeldit import comfy.text_encoders.wan import comfy.text_encoders.hidream import comfy.text_encoders.ace @@ -69,6 +69,7 @@ import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo import comfy.text_encoders.sa3 +import comfy.text_encoders.gpt_oss import comfy.model_patcher import comfy.lora @@ -335,41 +336,43 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model(tokens) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) + device = self.patcher.load_device + self.cond_stage_model.set_clip_options({"execution_device": device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: pbar = ProgressBar(len(scheduled_keyframes)) - for scheduled_opts in scheduled_keyframes: - t_range = scheduled_opts[0] - # don't bother encoding any conds outside of start_percent and end_percent bounds - if "start_percent" in add_dict: - if t_range[1] < add_dict["start_percent"]: - continue - if "end_percent" in add_dict: - if t_range[0] > add_dict["end_percent"]: - continue - hooks_keyframes = scheduled_opts[1] - for hook, keyframe in hooks_keyframes: - hook.hook_keyframe._current_keyframe = keyframe - # apply appropriate hooks with values that match new hook_keyframe - self.patcher.patch_hooks(all_hooks) - # perform encoding as normal - o = self.cond_stage_model.encode_token_weights(tokens) - cond, pooled = o[:2] - pooled_dict = {"pooled_output": pooled} - # add clip_start_percent and clip_end_percent in pooled - pooled_dict["clip_start_percent"] = t_range[0] - pooled_dict["clip_end_percent"] = t_range[1] - # add/update any keys with the provided add_dict - pooled_dict.update(add_dict) - # add hooks stored on clip - self.add_hooks_to_dict(pooled_dict) - all_cond_pooled.append([cond, pooled_dict]) - if show_pbar: - pbar.update(1) - model_management.throw_exception_if_processing_interrupted() + with model_management.cuda_device_context(device): + for scheduled_opts in scheduled_keyframes: + t_range = scheduled_opts[0] + # don't bother encoding any conds outside of start_percent and end_percent bounds + if "start_percent" in add_dict: + if t_range[1] < add_dict["start_percent"]: + continue + if "end_percent" in add_dict: + if t_range[0] > add_dict["end_percent"]: + continue + hooks_keyframes = scheduled_opts[1] + for hook, keyframe in hooks_keyframes: + hook.hook_keyframe._current_keyframe = keyframe + # apply appropriate hooks with values that match new hook_keyframe + self.patcher.patch_hooks(all_hooks) + # perform encoding as normal + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] + pooled_dict = {"pooled_output": pooled} + # add clip_start_percent and clip_end_percent in pooled + pooled_dict["clip_start_percent"] = t_range[0] + pooled_dict["clip_end_percent"] = t_range[1] + # add/update any keys with the provided add_dict + pooled_dict.update(add_dict) + # add hooks stored on clip + self.add_hooks_to_dict(pooled_dict) + all_cond_pooled.append([cond, pooled_dict]) + if show_pbar: + pbar.update(1) + model_management.throw_exception_if_processing_interrupted() all_hooks.reset() return all_cond_pooled @@ -383,8 +386,12 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model(tokens) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) - o = self.cond_stage_model.encode_token_weights(tokens) + device = self.patcher.load_device + self.cond_stage_model.set_clip_options({"execution_device": device}) + + with model_management.cuda_device_context(device): + o = self.cond_stage_model.encode_token_weights(tokens) + cond, pooled = o[:2] if return_dict: out = {"cond": cond, "pooled_output": pooled} @@ -446,9 +453,12 @@ class CLIP: self.cond_stage_model.reset_clip_options() self.load_model(tokens) + device = self.patcher.load_device self.cond_stage_model.set_clip_options({"layer": None}) - self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) - return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) + self.cond_stage_model.set_clip_options({"execution_device": device}) + + with model_management.cuda_device_context(device): + return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) def decode(self, token_ids, skip_special_tokens=True): return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) @@ -1026,50 +1036,52 @@ class VAE: do_tile = False if self.latent_dim == 2 and samples_in.ndim == 5: samples_in = samples_in[:, :, 0] - try: - memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = self.patcher.get_free_memory(self.device) - batch_number = int(free_memory / memory_used) - batch_number = max(1, batch_number) - # Pre-allocate output for VAEs that support direct buffer writes - preallocated = False - if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): - pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) - preallocated = True + with model_management.cuda_device_context(self.device): + try: + memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) + free_memory = self.patcher.get_free_memory(self.device) + batch_number = int(free_memory / memory_used) + batch_number = max(1, batch_number) - for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) - if preallocated: - self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) - else: - out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) - if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - pixel_samples[x:x+batch_number].copy_(out) - del out - self.process_output(pixel_samples[x:x+batch_number]) - except Exception as e: - model_management.raise_non_oom(e) - logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") - #NOTE: We don't know what tensors were allocated to stack variables at the time of the - #exception and the exception itself refs them all until we get out of this except block. - #So we just set a flag for tiler fallback so that tensor gc can happen once the - #exception is fully off the books. - do_tile = True + # Pre-allocate output for VAEs that support direct buffer writes + preallocated = False + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) + preallocated = True - if do_tile: - comfy.model_management.soft_empty_cache() - dims = samples_in.ndim - 2 - if dims == 1 or self.extra_1d_channel is not None: - pixel_samples = self.decode_tiled_1d(samples_in) - elif dims == 2: - pixel_samples = self.decode_tiled_(samples_in) - elif dims == 3: - tile = 256 // self.spacial_compression_decode() - overlap = tile // 4 - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + for x in range(0, samples_in.shape[0], batch_number): + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + if preallocated: + self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) + else: + out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + pixel_samples[x:x+batch_number].copy_(out) + del out + self.process_output(pixel_samples[x:x+batch_number]) + except Exception as e: + model_management.raise_non_oom(e) + logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: + comfy.model_management.soft_empty_cache() + dims = samples_in.ndim - 2 + if dims == 1 or self.extra_1d_channel is not None: + pixel_samples = self.decode_tiled_1d(samples_in) + elif dims == 2: + pixel_samples = self.decode_tiled_(samples_in) + elif dims == 3: + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples @@ -1087,20 +1099,21 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1 or self.extra_1d_channel is not None: - args.pop("tile_y") - output = self.decode_tiled_1d(samples, **args) - elif dims == 2: - output = self.decode_tiled_(samples, **args) - elif dims == 3: - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (max(1, overlap_t), overlap, overlap) - if tile_t is not None: - args["tile_t"] = max(2, tile_t) + with model_management.cuda_device_context(self.device): + if dims == 1 or self.extra_1d_channel is not None: + args.pop("tile_y") + output = self.decode_tiled_1d(samples, **args) + elif dims == 2: + output = self.decode_tiled_(samples, **args) + elif dims == 3: + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (max(1, overlap_t), overlap, overlap) + if tile_t is not None: + args["tile_t"] = max(2, tile_t) - output = self.decode_tiled_3d(samples, **args) + output = self.decode_tiled_3d(samples, **args) return output.movedim(1, -1) def encode(self, pixel_samples): @@ -1113,44 +1126,46 @@ class VAE: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: pixel_samples = pixel_samples.unsqueeze(2) - try: - memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = self.patcher.get_free_memory(self.device) - batch_number = int(free_memory / max(1, memory_used)) - batch_number = max(1, batch_number) - samples = None - for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) - if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): - out = self.first_stage_model.encode(pixels_in, device=self.device) + + with model_management.cuda_device_context(self.device): + try: + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) + model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) + free_memory = self.patcher.get_free_memory(self.device) + batch_number = int(free_memory / max(1, memory_used)) + batch_number = max(1, batch_number) + samples = None + for x in range(0, pixel_samples.shape[0], batch_number): + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) + if samples is None: + samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + samples[x:x + batch_number] = out + + except Exception as e: + model_management.raise_non_oom(e) + logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") + #NOTE: We don't know what tensors were allocated to stack variables at the time of the + #exception and the exception itself refs them all until we get out of this except block. + #So we just set a flag for tiler fallback so that tensor gc can happen once the + #exception is fully off the books. + do_tile = True + + if do_tile: + comfy.model_management.soft_empty_cache() + if self.latent_dim == 3: + tile = 256 + overlap = tile // 4 + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + elif self.latent_dim == 1 or self.extra_1d_channel is not None: + samples = self.encode_tiled_1d(pixel_samples) else: - pixels_in = pixels_in.to(self.device) - out = self.first_stage_model.encode(pixels_in) - out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) - if samples is None: - samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - samples[x:x + batch_number] = out - - except Exception as e: - model_management.raise_non_oom(e) - logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") - #NOTE: We don't know what tensors were allocated to stack variables at the time of the - #exception and the exception itself refs them all until we get out of this except block. - #So we just set a flag for tiler fallback so that tensor gc can happen once the - #exception is fully off the books. - do_tile = True - - if do_tile: - comfy.model_management.soft_empty_cache() - if self.latent_dim == 3: - tile = 256 - overlap = tile // 4 - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) - elif self.latent_dim == 1 or self.extra_1d_channel is not None: - samples = self.encode_tiled_1d(pixel_samples) - else: - samples = self.encode_tiled_(pixel_samples) + samples = self.encode_tiled_(pixel_samples) return samples @@ -1176,26 +1191,27 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1: - args.pop("tile_y") - samples = self.encode_tiled_1d(pixel_samples, **args) - elif dims == 2: - samples = self.encode_tiled_(pixel_samples, **args) - elif dims == 3: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) - else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + with model_management.cuda_device_context(self.device): + if dims == 1: + args.pop("tile_y") + samples = self.encode_tiled_1d(pixel_samples, **args) + elif dims == 2: + samples = self.encode_tiled_(pixel_samples, **args) + elif dims == 3: + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) return samples @@ -1269,6 +1285,8 @@ class CLIPType(Enum): FLUX2 = 25 LONGCAT_IMAGE = 26 COGVIDEOX = 27 + LENS = 28 + PIXELDIT = 29 @@ -1321,6 +1339,7 @@ class TEModel(Enum): GEMMA_4_E2B = 30 GEMMA_4_31B = 31 T5_GEMMA = 32 + GPT_OSS_20B = 33 def detect_te_model(sd): @@ -1362,6 +1381,9 @@ def detect_te_model(sd): else: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B + # Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512). + if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd: + return TEModel.GPT_OSS_20B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] if weight.shape[0] == 256: @@ -1508,8 +1530,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = variant.tokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: - clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer + if clip_type == CLIPType.PIXELDIT: + clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer + else: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.GEMMA_3_4B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") @@ -1544,6 +1570,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.GPT_OSS_20B: + clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.QWEN3_4B: if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") @@ -1710,12 +1740,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) - if output_model and out[0] is not None: - out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) - if output_clip and out[1] is not None: - out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) + if out[0] is not None: + out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) + # Register reload factories for the CLIP and VAE produced by the same checkpoint so + # ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device, + # MultiGPU work-units, etc.) without falling back to copy.deepcopy of an + # already-loaded module. + if out[1] is not None and getattr(out[1], "patcher", None) is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) + if out[2] is not None and getattr(out[2], "patcher", None) is not None: + out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options)) return out + +def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init + factory for the CLIP returned by load_checkpoint_guess_config.""" + _, clip, _, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=False, + output_clip=True, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return clip.patcher + + +def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + """Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init + factory for the VAE returned by load_checkpoint_guess_config.""" + _, _, vae, _ = load_checkpoint_guess_config( + ckpt_path, + output_vae=True, + output_clip=False, + output_clipvision=False, + embedding_directory=embedding_directory, + output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic, + ) + return vae.patcher + def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -1742,7 +1812,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) - load_device = model_management.get_torch_device() + load_device = model_options.get("load_device", model_management.get_torch_device()) custom_operations = model_options.get("custom_operations", None) if custom_operations is None: @@ -1782,13 +1852,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher - model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + offload_device = model_options.get("offload_device", model_management.unet_offload_device()) + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = model_config.process_vae_state_dict(vae_sd) - vae = VAE(sd=vae_sd, metadata=metadata) + vae_device = model_options.get("load_device", None) + vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device) if output_clip: if te_model_options.get("custom_operations", None) is None: @@ -1872,7 +1944,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) - load_device = model_management.get_torch_device() + load_device = model_options.get("load_device", model_management.get_torch_device()) model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) if model_config is not None: @@ -1897,7 +1969,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable else: logging.warning("{} {}".format(diffusers_keys[k], k)) - offload_device = model_management.unet_offload_device() + offload_device = model_options.get("offload_device", model_management.unet_offload_device()) unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.quant_config is not None: weight_dtype = None @@ -1939,6 +2011,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model + +def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False): + """Reload a disk-backed VAE from ``vae_path`` and return its patcher. + + Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so + :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a + fresh, untainted VAE patcher (no inherited per-device load state, no + in-place quantization fallout) for multigpu work-units and the + SelectVAEDevice node. The optional ``device`` matches the source loader's + VAE initialization path; the deepclone's ``load_device`` still controls + where the cloned patcher is targeted. + """ + if metadata is None: + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + else: + sd = comfy.utils.load_torch_file(vae_path) + vae = VAE(sd=sd, metadata=metadata, device=device) + vae.throw_exception_if_invalid() + return vae.patcher + def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 617db4f28..00941da53 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie import comfy.text_encoders.cogvideo import comfy.text_encoders.hidream_o1 +import comfy.text_encoders.pixeldit from . import supported_models_base from . import latent_formats @@ -829,6 +830,50 @@ class Flux2(Flux): return None + +class Lens(supported_models_base.BASE): + """Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE).""" + + unet_config = { + "image_model": "lens", + } + + sampling_settings = { + "shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300 + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + memory_usage_factor = 4.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device) + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + for hint in ("gpt_oss.transformer.", ""): + full_prefix = "{}{}".format(pref, hint) + if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict: + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix) + return supported_models_base.ClipTarget( + comfy.text_encoders.gpt_oss.LensTokenizer, + comfy.text_encoders.gpt_oss.lens_te(**detect), + ) + return supported_models_base.ClipTarget( + comfy.text_encoders.gpt_oss.LensTokenizer, + comfy.text_encoders.gpt_oss.lens_te(), + ) + + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage): def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) +class PixelDiTT2I(supported_models_base.BASE): + unet_config = { + "image_model": "pixeldit_t2i", + } + + unet_extra_config = {} + + sampling_settings = { + "shift": 4.0, # 1024px stage 3 default; 2.0 for 512px + } + + latent_format = latent_formats.PixelDiTPixel + memory_usage_factor = 0.04 + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PixelDiTT2I(self, device=device) + + def process_unet_state_dict(self, state_dict): + # pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim). + pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0] + + out = {} + marker = ".adaLN_modulation.0." + for k, v in state_dict.items(): + if k.startswith("_repa_projector") or k.startswith("net_ema."): + continue + if k.startswith("core."): + k = k[len("core."):] + elif k.startswith("net."): + k = k[len("net."):] + if "pixel_blocks." in k and marker in k: + # Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM + p2 = v.shape[0] // (6 * pixel_dim) + trail = v.shape[1:] # () for bias, (in_dim,) for weight + vv = v.view(p2, 6, pixel_dim, *trail) + base, suffix = k.split(marker) + out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous() + out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous() + else: + out[k] = v + return out + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget( + comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer, + comfy.text_encoders.pixeldit.PixelDiTGemma2TE, + ) + +class PiD(PixelDiTT2I): + unet_config = { + "image_model": "pid", + } + + sampling_settings = { + "shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0] + } + + memory_usage_factor = 0.04 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PiD(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -2069,6 +2180,8 @@ models = [ CosmosI2VPredict2, ZImagePixelSpace, ZImage, + PiD, + PixelDiTT2I, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, @@ -2096,6 +2209,7 @@ models = [ Omnigen2, QwenImage, Flux2, + Lens, Kandinsky5Image, Kandinsky5, Anima, diff --git a/comfy/text_encoders/gpt_oss.py b/comfy/text_encoders/gpt_oss.py new file mode 100644 index 000000000..d596ef9a0 --- /dev/null +++ b/comfy/text_encoders/gpt_oss.py @@ -0,0 +1,600 @@ +"""GPT-OSS text encoder for Lens.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +from comfy import sd1_clip +from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device +from comfy.text_encoders.llama import RMSNorm, apply_rope + + +@dataclass +class GptOss20BConfig: + vocab_size: int = 201088 + hidden_size: int = 2880 + intermediate_size: int = 2880 + num_hidden_layers: int = 24 + num_attention_heads: int = 64 + num_key_value_heads: int = 8 + head_dim: int = 64 + num_local_experts: int = 32 + num_experts_per_tok: int = 4 + sliding_window: int = 128 + original_max_position_embeddings: int = 4096 + rope_theta: float = 150000.0 + rope_factor: float = 32.0 + rope_beta_fast: float = 32.0 + rope_beta_slow: float = 1.0 + rope_truncate: bool = False + rms_norm_eps: float = 1e-5 + attention_bias: bool = True + layer_types: Optional[List[str]] = None + moe_alpha: float = 1.702 + moe_limit: float = 7.0 + + def __post_init__(self): + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 2 else "full_attention" + for i in range(self.num_hidden_layers) + ] + + +def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float, + original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]: + """YARN inv_freq + attention scaling (matches transformers).""" + dim = head_dim + + def find_correction_dim(num_rotations: float) -> float: + return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def find_correction_range() -> tuple[float, float]: + low = find_correction_dim(beta_fast) + high = find_correction_dim(beta_slow) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor: + if min_ == max_: + max_ += 0.001 + linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_) + return torch.clamp(linear, 0, 1) + + def get_mscale(scale: float) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + attention_scaling = get_mscale(factor) + + pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range() + extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2) + inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor + return inv_freq, attention_scaling + + +def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + pos_e = position_ids[:, None, :].float() + freqs = (inv_freq_e @ pos_e).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1) + sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1) + sin_split = sin.shape[-1] // 2 + return cos, sin[..., :sin_split], -sin[..., sin_split:] + + +def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor, + attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor: + """Attention with per-head sinks. + + Sinks add a learned term to each row's softmax denominator but contribute + nothing to the output. We fake this by appending one zero k/v position and + putting the sink logit in the mask at that column. + """ + + if num_kv_groups > 1 and not TORCH_HAS_GQA: + k = k.repeat_interleave(num_kv_groups, dim=1) + v = v.repeat_interleave(num_kv_groups, dim=1) + + B, _, S_q, D = q.shape + H_kv = k.shape[1] + S_kv = k.shape[-2] + + k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2) + v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2) + + sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1) + if attention_mask is not None: + mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv) + else: + mask_left = q.new_zeros(B, num_heads, S_q, S_kv) + mask = torch.cat([mask_left, sinks_col], dim=-1) + + op = optimized_attention_for_device(q.device, mask=True, small_input=True) + return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True) + + +class GptOssAttention(nn.Module): + def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): + super().__init__() + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + + bias = config.attention_bias + self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype) + self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype)) + + def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor: + B, S, _ = hidden_states.shape + + q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + + q, k = apply_rope(q, k, freqs_cis) + + out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups) + return self.o_proj(out) + + +# Mixture of Experts + +class GptOssTopKRouter(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype)) + self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False) + bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False) + logits = F.linear(hidden_states, weight, bias) + top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1) + # Softmax over top-k slice only + scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype) + return scores, top_idx + + +class GptOssExperts(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.alpha = config.moe_alpha + self.limit = config.moe_limit + + E = self.num_experts + H = self.hidden_size + I = self.intermediate_size + + self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype) + self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype) + + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate = gate_up[..., ::2] + up = gate_up[..., 1::2] + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + return torch.addcmul(glu, up, glu) + + def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + N = hidden_states.shape[0] + top_k = router_indices.shape[-1] + H = hidden_states.shape[-1] + + per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device) + + expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \ + self.down_proj.bank_resident(hidden_states) as down_bank: + for ei in expert_hit: + expert_idx = int(ei.item()) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current = hidden_states[token_idx] + + gate_up = gate_up_bank.expert_linear(current, expert_idx) + gated = self._apply_gate(gate_up) + expert_out = down_bank.expert_linear(gated, expert_idx) + + weighted = expert_out * routing_weights[token_idx, top_k_pos, None] + + flat_idx = token_idx * top_k + top_k_pos + per_pair[flat_idx] = weighted.to(per_pair.dtype) + + return per_pair.view(N, top_k, H).sum(dim=1) + + +class GptOssMLP(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.router = GptOssTopKRouter(config, device=device, dtype=dtype) + self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + flat = hidden_states.reshape(-1, H) + scores, idx = self.router(flat) + out = self.experts(flat, idx, scores) + return out.reshape(B, S, H) + + +# Decoder layer + model + +class GptOssDecoderLayer(nn.Module): + def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): + super().__init__() + self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops) + self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.layer_type = config.layer_types[layer_idx] + + def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor: + residual = x + x = self.input_layernorm(x) + x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis) + x = residual + x + + residual = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + x = residual + x + return x + + +def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device): + neg = torch.finfo(dtype).min + mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1) + mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous() + if key_padding_mask is not None: + kp = key_padding_mask.to(dtype=dtype) + kp = (1.0 - kp).reshape(B, 1, 1, S) * neg + mask = mask + kp + return mask + + +def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device): + neg = torch.finfo(dtype).min + i = torch.arange(S, device=device).view(-1, 1) + j = torch.arange(S, device=device).view(1, -1) + keep = (j <= i) & (j > i - window) + mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device)) + mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous() + if key_padding_mask is not None: + kp = key_padding_mask.to(dtype=dtype) + kp = (1.0 - kp).reshape(B, 1, 1, S) * neg + mask = mask + kp + return mask + + +class GptOssModel(nn.Module): + """GPT-OSS decoder with multi-layer hidden-state capture + early exit.""" + + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.config = config + self.dtype = dtype + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) + self.layers = nn.ModuleList( + [ + GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + + # Always build on CPU so the buffer survives meta-device construction. + inv_freq, attn_scaling = _yarn_inv_freq( + head_dim=config.head_dim, + base=config.rope_theta, + factor=config.rope_factor, + beta_fast=config.rope_beta_fast, + beta_slow=config.rope_beta_slow, + original_max_position_embeddings=config.original_max_position_embeddings, + truncate=config.rope_truncate, + device=torch.device("cpu"), + ) + self.register_buffer("rope_inv_freq", inv_freq, persistent=False) + self.rope_attention_scaling = float(attn_scaling) + + @property + def num_layers(self) -> int: + return self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.embed_tokens + + def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device, + ) -> dict[str, torch.Tensor]: + full = _make_full_causal_mask(B, S, attention_mask, dtype, device) + masks = {"full_attention": full} + if any(t == "sliding_attention" for t in self.config.layer_types): + masks["sliding_attention"] = _make_sliding_causal_mask( + B, S, self.config.sliding_window, attention_mask, dtype, device + ) + return masks + + def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, + capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]: + B, S = input_ids.shape + device = input_ids.device + dtype = self.dtype + + hidden_states = self.embed_tokens(input_ids, out_dtype=dtype) + + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype) + + attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device) + + capture_layers = list(capture_layers) if capture_layers else None + if capture_layers: + max_layer = max(capture_layers) + wanted = {idx: pos for pos, idx in enumerate(capture_layers)} + captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers) + else: + max_layer = self.config.num_hidden_layers - 1 + wanted = None + captured = None + + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, attn_masks, freqs_cis) + if wanted is not None and i in wanted: + captured[wanted[i]] = hidden_states + if i >= max_layer: + break + + if captured is not None: + return {"hidden_states": captured} + return {"last_hidden_state": self.norm(hidden_states)} + + +# Lens chat-template constants (verbatim from the reference pipeline). +_LENS_CHAT_SYSTEM = ( + "Describe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background." +) +_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description." +LENS_TXT_OFFSET = 97 +LENS_SELECTED_LAYERS = (5, 11, 17, 23) +LENS_MAX_TOKENS = 512 + + +# The reference GPT-OSS Harmony template injects today's date here +_LENS_CHAT_DATE = "2026-05-23" + + +def _lens_render_chat(prompt: str) -> str: + """Render the Lens prompt in GPT-OSS Harmony format.""" + return ( + f"<|start|>system<|message|>" + f"You are ChatGPT, a large language model trained by OpenAI.\n" + f"Knowledge cutoff: 2024-06\n" + f"Current date: {_LENS_CHAT_DATE}\n\n" + f"Reasoning: medium\n\n" + f"# Valid channels: analysis, commentary, final. " + f"Channel must be included for every message.<|end|>" + f"<|start|>developer<|message|># Instructions\n\n" + f"{_LENS_CHAT_SYSTEM}\n\n<|end|>" + f"<|start|>user<|message|>{prompt}<|end|>" + f"<|start|>assistant<|channel|>analysis<|message|>" + f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>" + f"<|start|>assistant<|channel|>final<|message|>" + ) + + +# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table). +_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|> + + +class _GptOssRawTokenizer: + """Raw ``tokenizers.Tokenizer`` wrapper. + + The tokenizer JSON ships as a byte tensor inside the encoder checkpoint + (``tokenizer_json`` key) rather than as a committed file. Extracted + it in ``sd.py`` and passes it here via ``tokenizer_data``. + """ + + def __init__(self, tokenizer_json_bytes=None, **kwargs): + from tokenizers import Tokenizer + if isinstance(tokenizer_json_bytes, torch.Tensor): + tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) + if tokenizer_json_bytes is None: + raise ValueError( + "Lens tokenizer requires the ``tokenizer_json`` byte tensor in the " + "encoder state dict. Re-bundle the encoder via bundle_te.py so it " + "embeds the tokenizer." + ) + self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) + + @classmethod + def from_pretrained(cls, tokenizer_data, **kwargs): + return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) + + def __call__(self, text): + return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids} + + def get_vocab(self): + return self.tokenizer.get_vocab() + + def convert_tokens_to_ids(self, tokens): + return [self.tokenizer.token_to_id(t) for t in tokens] + + def decode(self, ids, **kwargs): + return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False)) + + +class LensGptOssTokenizer(sd1_clip.SDTokenizer): + tokenizer_json_data = None + + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_json = tokenizer_data.get("tokenizer_json", None) + self.tokenizer_json_data = tokenizer_json + super().__init__( + tokenizer_json, + embedding_directory=embedding_directory, + pad_with_end=False, + embedding_size=2880, + embedding_key="gpt_oss", + tokenizer_class=_GptOssRawTokenizer, + has_start_token=False, + has_end_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=1, + pad_left=False, + disable_weights=True, + tokenizer_data=tokenizer_data, + ) + self.pad_token_id = _LENS_PAD_TOKEN_ID + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + # Empty prompt -> empty list; encode_token_weights returns zeros (uncond). + if not text or not text.strip(): + return [[]] + rendered = _lens_render_chat(text) + ids = self.tokenizer(rendered)["input_ids"] + if len(ids) > LENS_MAX_TOKENS: + ids = ids[:LENS_MAX_TOKENS] + return [[(int(t), 1.0) for t in ids]] + + def state_dict(self): + if self.tokenizer_json_data is not None: + return {"tokenizer_json": self.tokenizer_json_data} + return {} + + +class LensTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="gpt_oss", + tokenizer=LensGptOssTokenizer, + ) + + +class LensGptOssClipModel(nn.Module): + """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor).""" + + def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs): + super().__init__() + model_options = dict(model_options or {}) + + operations = model_options.get("custom_operations") + if operations is None: + quant_config = model_options.get("quantization_metadata") or {} + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + self.operations = operations + + cfg_overrides = model_options.get("gpt_oss_config", {}) + self.config = GptOss20BConfig(**cfg_overrides) + self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS)) + self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET)) + + self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + self.dtype = dtype + self.execution_device = None + self._pad_token_id = _LENS_PAD_TOKEN_ID + + def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) + + def reset_clip_options(self): + self.execution_device = None + + def _gather_tokens(self, token_weight_pairs): + ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs] + pad_id = self._pad_token_id + max_len = max(len(x) for x in ids_list) + device = self.execution_device + ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device) + mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device) + for i, x in enumerate(ids_list): + ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device) + mask[i, : len(x)] = 1 + return ids, mask + + def encode_token_weights(self, token_weight_pairs): + # Empty negative: emit zero-length features + zero mask + if all(len(batch) == 0 for batch in token_weight_pairs): + device = self.execution_device + B = len(token_weight_pairs) + L = len(self.selected_layers) + H = self.config.hidden_size + flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device) + mask = torch.zeros(B, 0, dtype=torch.long, device=device) + return flat, None, {"attention_mask": mask, "num_layers_stacked": L} + + input_ids, attn_mask = self._gather_tokens(token_weight_pairs) + out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers) + layers = out["hidden_states"] # list of L × [B, S, H] + stacked = torch.stack(layers, dim=2) # [B, S, L, H] + + offset = self.txt_offset + if stacked.shape[1] > offset: + stacked = stacked[:, offset:].contiguous() + mask_trim = attn_mask[:, offset:] + else: + stacked = stacked[:, :0] + mask_trim = attn_mask[:, :0] + + B, S, L, H = stacked.shape + flat = stacked.reshape(B, S, L * H) + extra = {"attention_mask": mask_trim, "num_layers_stacked": L} + return flat, None, extra + + def load_sd(self, sd): + return self.transformer.load_state_dict(sd, strict=False, assign=True) + + +class LensTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options=None): + super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {}) + + +def lens_te(dtype_llama=None, llama_quantization_metadata=None): + class LensTEModel_(LensTEModel): + def __init__(self, device="cpu", dtype=None, model_options=None): + mo = dict(model_options or {}) + if llama_quantization_metadata is not None: + mo["quantization_metadata"] = llama_quantization_metadata + if dtype is None and dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=mo) + + return LensTEModel_ diff --git a/comfy/text_encoders/pixeldit.py b/comfy/text_encoders/pixeldit.py new file mode 100644 index 000000000..3539711e4 --- /dev/null +++ b/comfy/text_encoders/pixeldit.py @@ -0,0 +1,104 @@ +import torch + +from comfy import sd1_clip +from .lumina2 import Gemma2BTokenizer, LuminaModel +import comfy.text_encoders.llama + + +class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__( + device=device, layer=layer, layer_idx=layer_idx, + textmodel_json_config={}, dtype=dtype, + special_tokens={"start": 2, "pad": 0}, + layer_norm_hidden_state=False, + model_class=comfy.text_encoders.llama.Gemma2_2B, + enable_attention_masks=attention_mask, + return_attention_masks=attention_mask, + model_options=model_options, + ) + + +_PIXELDIT_CHI_PROMPT = ( + 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions ' + "suitable for image generation. Evaluate the level of detail in the user prompt:\n" + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, " + "and spatial relationships to create vivid and concrete scenes.\n" + "- If the prompt is already detailed, refine and enhance the existing details slightly without " + "overcomplicating.\n" + "Here are examples of how to transform or refine prompts:\n" + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, " + "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n" + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring " + "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus " + "passing by towering glass skyscrapers.\n" + "Please generate only the enhanced description for the prompt below and avoid including any " + "additional commentary or evaluations:\n" + "User Prompt: " +) + +_PIXELDIT_MAX_LENGTH = 300 +_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"' + + +class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, + name="gemma2_2b", tokenizer=Gemma2BTokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + if not text.strip(): + return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH) + + chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"]) + combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text + max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2 + out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids, + disable_weights=True, min_length=max_length_all) + out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]] + return out + + def untokenize(self, token_weight_pair): + return self.gemma2_2b.untokenize(token_weight_pair) + + def state_dict(self): + return self.gemma2_2b.state_dict() + + +class PixelDiTGemma2TE(LuminaModel): + # PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence. + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="gemma2_2b", + clip_model=PixelDiTGemma2_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + result = super().encode_token_weights(token_weight_pairs) + cond, pooled = result[0], result[1] + extra = result[2] if len(result) > 2 else None + if cond.shape[1] > _PIXELDIT_MAX_LENGTH: + cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1) + if extra is not None and "attention_mask" in extra: + am = extra["attention_mask"] + extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1) + if extra is not None: + return cond, pooled, extra + return cond, pooled + + +def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None): + class PixelDiTTE_(PixelDiTGemma2TE): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return PixelDiTTE_ diff --git a/comfy/utils.py b/comfy/utils.py index 31052714a..49ae12b06 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -86,6 +86,7 @@ def load_safetensors(ckpt): import comfy_aimdo.model_mmap f = open(ckpt, "rb", buffering=0) + file_lock = threading.Lock() model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) file_size = os.path.getsize(ckpt) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) @@ -111,7 +112,7 @@ def load_safetensors(ckpt): storage = tensor.untyped_storage() setattr(storage, "_comfy_tensor_file_slice", - comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) + comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) sd[name] = tensor diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 04973fea0..e0a585b10 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from abc import ABC, abstractmethod from typing import TYPE_CHECKING from comfy_api.internal import ComfyAPIBase diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 942278d88..99e67d363 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -1,4 +1,3 @@ -from __future__ import annotations from av.container import InputContainer from av.subtitles.stream import SubtitleStream from fractions import Fraction diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index c92477f08..6c9d6a526 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -1,4 +1,3 @@ -from __future__ import annotations from dataclasses import dataclass from enum import Enum from fractions import Fraction diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 46a583b5e..9c4cfb9b6 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -3,7 +3,6 @@ # timestamp: 2025-07-30T08:54:00+00:00 # pylint: disable -from __future__ import annotations from datetime import date, datetime from enum import Enum diff --git a/comfy_api_nodes/apis/bfl.py b/comfy_api_nodes/apis/bfl.py index d8d3557b3..f0665fa09 100644 --- a/comfy_api_nodes/apis/bfl.py +++ b/comfy_api_nodes/apis/bfl.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from enum import Enum from typing import Any, Dict, Optional diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 03f4c445b..47f24586c 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel): class SeedanceVirtualLibraryCreateAssetRequest(BaseModel): - url: str = Field(..., description="Publicly accessible URL of the image asset to upload.") + url: str = Field(..., description="Publicly accessible URL of the asset to upload.") hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.") + asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.") # Dollars per 1K tokens, keyed by (model_id, has_video_input). diff --git a/comfy_api_nodes/apis/krea.py b/comfy_api_nodes/apis/krea.py new file mode 100644 index 000000000..6e294a3b7 --- /dev/null +++ b/comfy_api_nodes/apis/krea.py @@ -0,0 +1,46 @@ +"""Pydantic models for the Krea image-generation API.""" + +from pydantic import BaseModel, Field + + +class KreaMoodboard(BaseModel): + id: str = Field(...) + strength: float = Field(default=0.35, ge=-0.5, le=1.5) + + +class KreaImageStyleReference(BaseModel): + strength: float = Field(..., ge=-2.0, le=2.0) + url: str | None = Field(default=None) + + +class KreaGenerateImageRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str = Field(...) + resolution: str = Field(...) + seed: int | None = Field(default=None) + creativity: str = Field(default="medium") + moodboards: list[KreaMoodboard] | None = Field(default=None) + image_style_references: list[KreaImageStyleReference] | None = Field(default=None) + + +class KreaJobResult(BaseModel): + urls: list[str] | None = Field(default=None) + style_id: str | None = Field(default=None) + + +class KreaJob(BaseModel): + job_id: str = Field(...) + status: str = Field(...) + created_at: str = Field(...) + completed_at: str | None = Field(default=None) + result: KreaJobResult | None = Field(default=None) + + +class KreaAssetResponse(BaseModel): + id: str = Field(...) + image_url: str = Field(...) + uploaded_at: str = Field(...) + width: float | None = Field(default=None) + height: float | None = Field(default=None) + size_bytes: float | None = Field(default=None) + mime_type: str | None = Field(default=None) diff --git a/comfy_api_nodes/apis/stability.py b/comfy_api_nodes/apis/stability.py index 718360187..5b9b5ac7d 100644 --- a/comfy_api_nodes/apis/stability.py +++ b/comfy_api_nodes/apis/stability.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from enum import Enum from typing import Optional diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index e08fc0b01..1b4abafe6 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -2,11 +2,12 @@ import hashlib import logging import math import re +from io import BytesIO import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, @@ -308,6 +309,26 @@ async def _seedance_virtual_library_upload_image_asset( return f"asset://{create_resp.asset_id}" +async def _seedance_virtual_library_upload_video_asset( + cls: type[IO.ComfyNode], + video: Input.Video, + *, + wait_label: str = "Uploading video", +) -> str: + buf = BytesIO() + video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264) + video_hash = hashlib.sha256(buf.getbuffer()).hexdigest() + public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label) + create_resp = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"), + response_model=SeedanceCreateAssetResponse, + data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"), + ) + await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library") + return f"asset://{create_resp.asset_id}" + + def _seedance2_price_extractor(model_id: str, has_video_input: bool): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) @@ -2106,7 +2127,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): content.append( TaskVideoContent( video_url=TaskVideoContentUrl( - url=await upload_video_to_comfyapi( + url=await _seedance_virtual_library_upload_video_asset( cls, reference_videos[key], wait_label=f"Uploading video {i}", diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py new file mode 100644 index 000000000..003a8a654 --- /dev/null +++ b/comfy_api_nodes/nodes_krea.py @@ -0,0 +1,290 @@ +"""Krea image-generation nodes.""" + +import re + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.krea import ( + KreaAssetResponse, + KreaGenerateImageRequest, + KreaImageStyleReference, + KreaJob, + KreaMoodboard, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + poll_op, + sync_op, + tensor_to_bytesio, + validate_string, +) + + +class KreaIO: + STYLE_REF = "KREA_STYLE_REF" + + +async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str: + """Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL.""" + img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png") + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"), + response_model=KreaAssetResponse, + files=[("file", (img_io.name, img_io, "image/png"))], + content_type="multipart/form-data", + max_retries=1, + wait_label="Uploading reference", + ) + return response.image_url + + +_MODEL_MEDIUM = "Krea 2 Medium" +_MODEL_LARGE = "Krea 2 Large" +_MODEL_ENDPOINTS: dict[str, str] = { + _MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium", + _MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large", +} + +_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"] +_RESOLUTIONS = ["1K"] +_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"] +_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"] + +_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$") + + +def _krea_model_inputs() -> list: + """Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo.""" + return [ + IO.Combo.Input( + "aspect_ratio", + options=_ASPECT_RATIOS, + tooltip="Output aspect ratio.", + ), + IO.Combo.Input( + "resolution", + options=_RESOLUTIONS, + tooltip="Resolution scale.", + ), + IO.Combo.Input( + "creativity", + options=_CREATIVITY_LEVELS, + default="medium", + tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.", + ), + IO.String.Input( + "moodboard_id", + default="", + tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). " + "Leave empty to disable. Only one moodboard is supported per request.", + optional=True, + ), + IO.Float.Input( + "moodboard_strength", + default=0.35, + min=-0.5, + max=1.5, + step=0.05, + tooltip="Moodboard influence; ignored when moodboard_id is empty.", + optional=True, + ), + IO.Custom(KreaIO.STYLE_REF).Input( + "style_reference", + optional=True, + tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.", + ), + ] + + +class Krea2ImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Krea2ImageNode", + display_name="Krea 2 Image", + category="api node/image/Krea", + description=( + "Generate images via Krea 2 — pick Medium (expressive illustrations) or " + "Large (expressive photorealism). Supports an optional moodboard and up " + "to 10 chained image style references." + ), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the image.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()), + IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()), + ], + tooltip="Krea 2 Medium is best for expressive illustrations; " + "Krea 2 Large is best for expressive photorealism.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Random seed for reproducibility.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model", "model.moodboard_id"], + inputs=["model.style_reference"], + ), + expr=""" + ( + $isLarge := widgets.model = "krea 2 large"; + $hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0; + $hasStyle := $lookup(inputs, "model.style_reference").connected; + $usd := $hasMoodboard + ? ($isLarge ? 0.07 : 0.04) + : ($hasStyle + ? ($isLarge ? 0.065 : 0.035) + : ($isLarge ? 0.06 : 0.03)); + {"type":"usd","usd": $usd} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1) + + model_choice = model["model"] + endpoint_path = _MODEL_ENDPOINTS.get(model_choice) + if endpoint_path is None: + raise ValueError(f"Unknown Krea 2 model: {model_choice!r}") + + moodboards: list[KreaMoodboard] | None = None + mb_id = (model.get("moodboard_id") or "").strip() + if mb_id: + if not _UUID_RE.match(mb_id): + raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.") + mb_strength = model.get("moodboard_strength") + moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))] + + style_reference = model.get("style_reference") + image_style_references: list[KreaImageStyleReference] | None = None + if style_reference: + if len(style_reference) > 10: + raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.") + image_style_references = [ + KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference + ] + initial = await sync_op( + cls, + ApiEndpoint(path=endpoint_path, method="POST"), + response_model=KreaJob, + data=KreaGenerateImageRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + resolution=model["resolution"], + seed=seed, + creativity=model["creativity"], + moodboards=moodboards, + image_style_references=image_style_references, + ), + ) + job = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"), + response_model=KreaJob, + status_extractor=lambda r: r.status, + queued_statuses=_KREA_QUEUED_STATUSES, + ) + if not job.result or not job.result.urls: + raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.") + image = await download_url_to_image_tensor(job.result.urls[0]) + return IO.NodeOutput(image) + + +class Krea2StyleReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Krea2StyleReferenceNode", + display_name="Krea 2 Style Reference", + category="api node/image/Krea", + description=( + "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 " + "Style Reference nodes (max 10) and feed the final `style_reference` output " + "into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL." + ), + inputs=[ + IO.Image.Input( + "image", + tooltip="Reference image whose style influences the generation.", + ), + IO.Float.Input( + "strength", + default=1.0, + min=-2.0, + max=2.0, + step=0.05, + tooltip="Reference strength; negative values invert the style influence.", + ), + IO.Custom(KreaIO.STYLE_REF).Input( + "style_reference", + optional=True, + tooltip="Optional incoming chain of style references; this node appends one more.", + ), + ], + outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + strength: float, + style_reference: list[dict] | None = None, + ) -> IO.NodeOutput: + chain: list[dict] = list(style_reference) if style_reference else [] + if len(chain) >= 10: + raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.") + url = await _upload_image_to_krea_assets(cls, image) + chain.append({"url": url, "strength": float(strength)}) + return IO.NodeOutput(chain) + + +class KreaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Krea2ImageNode, + Krea2StyleReferenceNode, + ] + + +async def comfy_entrypoint() -> KreaExtension: + return KreaExtension() diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index c47f3c79b..479ee8a53 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -1,4 +1,3 @@ -from __future__ import annotations from typing import Type, Literal import nodes diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index f951a3350..731b8dc66 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from typing import TypedDict, Dict, Optional, Tuple from typing_extensions import override from PIL import Image diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index e73624bd1..ae9a2376c 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -1,4 +1,3 @@ -from __future__ import annotations from comfy_api.latest import IO diff --git a/comfy_extras/mediapipe/face_geometry.py b/comfy_extras/mediapipe/face_geometry.py index 04b2b0557..4f3813430 100644 --- a/comfy_extras/mediapipe/face_geometry.py +++ b/comfy_extras/mediapipe/face_geometry.py @@ -2,7 +2,6 @@ + weighted Procrustes solver. Computes the 4x4 facial transformation matrix. """ -from __future__ import annotations import math import numpy as np diff --git a/comfy_extras/mediapipe/face_landmarker.py b/comfy_extras/mediapipe/face_landmarker.py index 6a9a25f82..95c67b321 100644 --- a/comfy_extras/mediapipe/face_landmarker.py +++ b/comfy_extras/mediapipe/face_landmarker.py @@ -1,7 +1,6 @@ """Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task: BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes.""" -from __future__ import annotations import math from functools import lru_cache diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index d5084497e..f09a8a874 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import av import torchaudio import torch diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py index 4ebb4b51e..b585c560f 100644 --- a/comfy_extras/nodes_cfg.py +++ b/comfy_extras/nodes_cfg.py @@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode): inputs=[ io.Model.Input("model"), io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01), + io.Boolean.Input( + "pre_cfg", + default=False, + optional=True, + tooltip=( + "If true, rescale the combined noise BEFORE the sampler's CFG combine, " + "without clamping (can amplify). Matches the norm-scaled CFG used by " + "models like Lens. Default false keeps the original post-CFG x0-space " + "attenuate-only behavior." + ), + ), ], outputs=[io.Model.Output(display_name="patched_model")], is_experimental=True, ) @classmethod - def execute(cls, model, strength) -> io.NodeOutput: + def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput: m = model.clone() - def cfg_norm(args): - cond_p = args['cond_denoised'] - pred_text_ = args["denoised"] + if pre_cfg: + def cfg_norm_pre(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + comb = uncond + cond_scale * (cond - uncond) + cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True) + comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True) + rescale = torch.where( + comb_norm > 0, + cond_norm / comb_norm.clamp_min(1e-12), + torch.ones_like(comb_norm), + ) + rescaled = comb * rescale + # strength blends back toward standard linear CFG (1.0 = full rescale). + if strength != 1.0: + rescaled = strength * rescaled + (1.0 - strength) * comb + return rescaled + m.set_model_sampler_cfg_function(cfg_norm_pre) + else: + def cfg_norm(args): + cond_p = args['cond_denoised'] + pred_text_ = args["denoised"] - norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) - norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) - scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) - return pred_text_ * scale * strength + norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) + norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) + scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) + return pred_text_ * scale * strength - m.set_model_sampler_post_cfg_function(cfg_norm) + m.set_model_sampler_post_cfg_function(cfg_norm) return io.NodeOutput(m) diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index f7ca833dc..24729c3a7 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -1,4 +1,3 @@ -from __future__ import annotations from comfy_api.latest import ComfyExtension, io import comfy.context_windows import nodes diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py index 9803e8034..099453131 100644 --- a/comfy_extras/nodes_curve.py +++ b/comfy_extras/nodes_curve.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import numpy as np from comfy_api.latest import ComfyExtension, io diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 33933229d..fe6008aa3 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import nodes import folder_paths diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 342cadb69..92507f1fc 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -1,4 +1,3 @@ -from __future__ import annotations from typing import TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 51cf7951f..48d75c9e5 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -226,10 +226,20 @@ def get_noise_mask(latent): noise_mask = noise_mask.clone() return noise_mask -def get_keyframe_idxs(cond): +def get_keyframe_idxs(cond, latent_shape=None): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 + # Get number of keyframes from latent_shape or guide_attention_entries if available + if latent_shape is not None and len(latent_shape) == 5: + tokens_per_frame = latent_shape[-2] * latent_shape[-1] + num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame + return keyframe_idxs, num_keyframes + entries = conditioning_get_any_value(cond, "guide_attention_entries", None) + if entries: + num_keyframes = sum(e["latent_shape"][0] for e in entries) + return keyframe_idxs, num_keyframes + # fallback, may under-count if keyframes share t-start # keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes @@ -322,9 +332,9 @@ class LTXVAddGuide(io.ComfyNode): return factor @classmethod - def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): + def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None): time_scale_factor, _, _ = scale_factors - _, num_keyframes = get_keyframe_idxs(cond) + _, num_keyframes = get_keyframe_idxs(cond, latent_shape) latent_count = latent_length - num_keyframes frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0) if guide_length > 1 and frame_idx != 0: @@ -436,7 +446,7 @@ class LTXVAddGuide(io.ComfyNode): num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 resolved_frame_idx = frame_idx if frame_idx < 0: - _, num_keyframes = get_keyframe_idxs(positive) + _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape) resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0) causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1 @@ -454,7 +464,7 @@ class LTXVAddGuide(io.ComfyNode): if latent_downscale_factor > 1: t, guide_mask = cls.dilate_latent(t, latent_downscale_factor) - frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." positive, negative, latent_image, noise_mask = cls.append_keyframe( @@ -506,7 +516,7 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent["samples"].clone() noise_mask = get_noise_mask(latent) - _, num_keyframes = get_keyframe_idxs(positive) + _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape) if num_keyframes == 0: return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index 06aefa475..0040d1a92 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -4,7 +4,6 @@ Provides a ComfyMathExpression node that evaluates math expressions against dynamically-grown numeric inputs. """ -from __future__ import annotations import math import string diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py index 6b7916aee..32dc22de3 100644 --- a/comfy_extras/nodes_mediapipe.py +++ b/comfy_extras/nodes_mediapipe.py @@ -10,7 +10,6 @@ Custom IO types: MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes. """ -from __future__ import annotations import numpy as np import torch diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index 3508781a0..79aec5d7f 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -1,6 +1,5 @@ """ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration.""" -from __future__ import annotations import torch diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py new file mode 100644 index 000000000..d2f6fe67a --- /dev/null +++ b/comfy_extras/nodes_multigpu.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +import copy +import logging +from inspect import cleandoc +from typing import TYPE_CHECKING +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + +if TYPE_CHECKING: + from comfy.model_patcher import ModelPatcher + from comfy.sd import CLIP, VAE +import torch + +import comfy.model_management +import comfy.multigpu + + +class MultiGPUCFGSplitNode(io.ComfyNode): + """ + Prepares model to have sampling accelerated via splitting work units. + + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. + + Other than those exceptions, this node can be placed in any order. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_WorkUnits", + display_name="MultiGPU CFG Split", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Model.Input("model"), + io.Int.Input("max_gpus", default=2, min=1, step=1), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) + return io.NodeOutput(model) + + +def _force_supported_compute_dtype(patcher: ModelPatcher, device: torch.device): + """Cast compute dtype to one the device supports; no-op if already supported.""" + weight_dtype = patcher.model_dtype() + cast_dtype = comfy.model_management.unet_manual_cast(weight_dtype, device) + if cast_dtype is None: + return + logging.info(f"Select Model Device: using {cast_dtype} compute dtype on {device} (model weight dtype was {weight_dtype}).") + patcher.set_model_compute_dtype(cast_dtype) + + +def _remember_base_devices(patcher: ModelPatcher): + """Stash the original load/offload device on the underlying model. + + Stored on patcher.model (which is shared with the input patcher), so + later "default" selections can recover the loader's original routing. + Only the first Select on a given chain writes these attrs; subsequent + deepclones inherit them onto their freshly-loaded model below. + """ + if not hasattr(patcher.model, "_select_base_load_device"): + patcher.model._select_base_load_device = patcher.load_device + patcher.model._select_base_offload_device = patcher.offload_device + + +def _propagate_base_devices(src_model, dst_model): + """Carry the loader-original device attrs onto the freshly-deepcloned model.""" + if hasattr(src_model, "_select_base_load_device") and not hasattr(dst_model, "_select_base_load_device"): + dst_model._select_base_load_device = src_model._select_base_load_device + dst_model._select_base_offload_device = src_model._select_base_offload_device + + +def _retarget_patcher(patcher: ModelPatcher, target_load_device, target_offload_device): + """Return a patcher whose actual model weights live on *target_load_device*. + + If *patcher* is already on *target_load_device* we just retarget the + (already-cloned) patcher's metadata in place. Otherwise we call + :meth:`ModelPatcher.deepclone_multigpu` to spawn a fresh model from + the loader's ``cached_patcher_init`` factory -- the only safe way to + move weights that may already be partially loaded onto another device. + + NOTE: reusing the input patcher's model when the requested device + matches its current load_device is a deliberate fast path. Anything + that has already mutated the original model (e.g. a prior KSampler + invocation on the same model) will be observed here. This is by + design and documented on the SelectXDeviceNode docstrings -- placing + Select X Device after a node that consumes the same model is not + recommended. + """ + if patcher.load_device == target_load_device: + # Fast path: weights already on the desired device, just update offload. + patcher.offload_device = target_offload_device + return patcher + src_model = patcher.model + patcher = patcher.deepclone_multigpu(new_load_device=target_load_device) + patcher.offload_device = target_offload_device + _propagate_base_devices(src_model, patcher.model) + if hasattr(patcher, "register_load_device"): + patcher.register_load_device(patcher.load_device) + return patcher + + +def _apply_patcher_device(patcher: ModelPatcher, resolved, base_offload_override=None): + """Resolve the requested device and produce a patcher routed there. + + For "default" we restore the loader's original load/offload pair. + For CPU we pin both load and offload to CPU (and, on a dynamic + patcher, downgrade to a plain ModelPatcher so the dynamic-only + code paths are bypassed). + For an explicit GPU we keep the loader's original offload but + target the requested load device; if that differs from the current + load device the patcher is deepcloned onto the new device. + """ + _remember_base_devices(patcher) + base_load = patcher.model._select_base_load_device + base_offload = base_offload_override if base_offload_override is not None else patcher.model._select_base_offload_device + + if resolved is None: + # "default" -> route back to the loader's original devices. + return _retarget_patcher(patcher, base_load, base_offload) + if resolved.type == "cpu": + if patcher.is_dynamic(): + # clone(disable_dynamic=True) requires cached_patcher_init; let the + # exception surface to the caller (Select*DeviceNode.execute), which + # will translate it into a passthrough+log so unsupported loaders + # don't hard-fail the workflow. + patcher = patcher.clone(disable_dynamic=True) + patcher.load_device = resolved + patcher.offload_device = resolved + return patcher + return _retarget_patcher(patcher, resolved, base_offload) + + +def _prune_multigpu_collision(model: ModelPatcher, primary_device): + """Drop any multigpu clone whose load_device matches *primary_device*. + + Without pruning, MultiGPU CFG Split would have stacked a clone on + the same device the primary now occupies (i.e. the workflow places + MultiGPU CFG Split before Select Model Device). Keeps the clone set + consistent with the new primary placement. + """ + multigpu_models = model.get_additional_models_with_key("multigpu") + if not multigpu_models: + return + filtered = [m for m in multigpu_models if m.load_device != primary_device] + if len(filtered) != len(multigpu_models): + logging.info(f"Select Model Device: pruning MultiGPU clone on {primary_device} that now collides with the primary model.") + model.set_additional_models("multigpu", filtered) + if hasattr(model, "match_multigpu_clones"): + model.match_multigpu_clones() + + +class SelectModelDeviceNode(io.ComfyNode): + """ + Place the diffusion model on a specific device (default / cpu / gpu:N). + + - "default" restores the device assigned by the loader (even after a + prior Select Model Device call). + - "cpu" pins both the load and offload device to CPU. + - "gpu:N" pins the load device to the Nth available GPU; the offload + device is restored to the loader's original choice. + + When the requested device differs from the device the input model is + already on, a fresh model is spawned via the loader's reload factory + (cached_patcher_init) so the new patcher owns independent weights on + the new device. Loaders that don't support multigpu (no factory) will + cause the node to pass through unchanged with a warning. + + If the workflow already has MultiGPU CFG Split applied and the chosen + GPU collides with one of the existing multigpu clones, that clone is + dropped so two patchers don't end up bound to the same device. + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the model through unchanged and logs a message + instead of failing. + + NOTE: Placing Select Model Device *after* a node that has already + consumed the same model (e.g. a KSampler that ran on this model on + the original device) is not recommended -- any state the prior + consumer mutated on the original model will be observed when the + selected device matches the original (fast path). Place Select Model + Device before any consumer of the model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectModelDevice", + display_name="Select Model Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Model.Input("model"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def validate_inputs(cls, device="default"): + # Allow unknown gpu:N values so portable workflows do not error + # at validation time; runtime fallback will handle them. + return True + + @classmethod + def execute(cls, model: ModelPatcher, device: str = "default") -> io.NodeOutput: + model = model.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None and device not in (None, "default"): + logging.info(f"Select Model Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(model) + try: + model = _apply_patcher_device(model, resolved) + except RuntimeError as e: + logging.warning(f"Select Model Device: cannot retarget model, passing through unchanged. ({e})") + return io.NodeOutput(model) + if resolved is not None: + _force_supported_compute_dtype(model, resolved) + _prune_multigpu_collision(model, model.load_device) + return io.NodeOutput(model) + + +class SelectCLIPDeviceNode(io.ComfyNode): + """ + Place the CLIP text encoder on a specific device (default / cpu / gpu:N). + + - "default" restores the device assigned by the loader. + - "cpu" pins both the load and offload device to CPU. + - "gpu:N" pins the load device to the Nth available GPU. + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the CLIP through unchanged and logs a message + instead of failing. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectCLIPDevice", + display_name="Select CLIP Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Clip.Input("clip"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options()), + ], + outputs=[ + io.Clip.Output(), + ], + ) + + @classmethod + def validate_inputs(cls, device="default"): + return True + + @classmethod + def execute(cls, clip: CLIP, device: str = "default") -> io.NodeOutput: + clip = clip.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None and device not in (None, "default"): + logging.info(f"Select CLIP Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(clip) + try: + clip.patcher = _apply_patcher_device(clip.patcher, resolved) + except RuntimeError as e: + logging.warning(f"Select CLIP Device: cannot retarget CLIP, passing through unchanged. ({e})") + return io.NodeOutput(clip) + + +class SelectVAEDeviceNode(io.ComfyNode): + """ + Place the VAE on a specific device (default / gpu:N). + + - "default" restores the device assigned by the loader. + - "gpu:N" pins the load device to the Nth available GPU; the offload + device is set to the standard VAE offload device. + + CPU is intentionally not exposed in the UI for the VAE; if a workflow + supplies "cpu" anyway (e.g. opened from another machine), the request + is dropped with a log message and the VAE is passed through unchanged. + + When the selected device does not exist on the current machine + (e.g. a workflow built on a 2-GPU box opened on a 1-GPU box), + the node passes the VAE through unchanged and logs a message + instead of failing. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SelectVAEDevice", + display_name="Select VAE Device", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Vae.Input("vae"), + io.Combo.Input("device", options=comfy.model_management.get_gpu_device_options_no_cpu()), + ], + outputs=[ + io.Vae.Output(), + ], + ) + + @classmethod + def validate_inputs(cls, device="default"): + return True + + @classmethod + def execute(cls, vae: VAE, device: str = "default") -> io.NodeOutput: + # VAE has no .clone(); shallow-copy the wrapper and clone the patcher + # so we can retarget load/offload device without affecting the input VAE. + vae = copy.copy(vae) + vae.patcher = vae.patcher.clone() + resolved = comfy.model_management.resolve_gpu_device_option(device) + if resolved is None and device not in (None, "default"): + logging.info(f"Select VAE Device: requested device '{device}' not available, passing through unchanged.") + return io.NodeOutput(vae) + if resolved is not None and resolved.type == "cpu": + logging.info("Select VAE Device: CPU is not a supported choice, passing through unchanged.") + return io.NodeOutput(vae) + if not hasattr(vae, "_select_base_device"): + vae._select_base_device = vae.device + try: + vae.patcher = _apply_patcher_device( + vae.patcher, resolved, + base_offload_override=comfy.model_management.vae_offload_device(), + ) + except RuntimeError as e: + logging.warning(f"Select VAE Device: cannot retarget VAE, passing through unchanged. ({e})") + return io.NodeOutput(vae) + # Keep VAE wrapper in sync with whatever model the patcher now owns; + # deepclone_multigpu may have produced a fresh first_stage_model. + vae.first_stage_model = vae.patcher.model + vae.device = vae._select_base_device if resolved is None else resolved + return io.NodeOutput(vae) + + +class MultiGPUOptionsNode(io.ComfyNode): + """ + Select the relative speed of GPUs in the special case they have significantly different performance from one another. + + NOTE (not registered yet, see MultiGPUExtension.get_node_list below): + The output GPUOptionsGroup is plumbed through create_multigpu_deepclones() and stored on + model.model_options['multigpu_options'] via GPUOptionsGroup.register(), but the cond + scheduler in comfy/samplers.py (calc_cond_batch_outer_multigpu) does NOT yet consult + relative_speed when distributing conds across devices; it uses a uniform conds_per_device + round-robin via next_available_device(). Before re-enabling this node, wire its + relative_speed into the scheduler (e.g. via comfy.multigpu.load_balance_devices(), + which already implements the proportional split) so the input actually affects work + distribution. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MultiGPU_Options", + display_name="MultiGPU Options", + category="advanced/multigpu", + description=cleandoc(cls.__doc__), + inputs=[ + io.Int.Input("device_index", default=0, min=0, max=64), + io.Float.Input("relative_speed", default=1.0, min=0.0, step=0.01), + io.Custom("GPU_OPTIONS").Input("gpu_options", optional=True), + ], + outputs=[ + io.Custom("GPU_OPTIONS").Output(), + ], + ) + + @classmethod + def execute(cls, device_index: int, relative_speed: float, gpu_options: comfy.multigpu.GPUOptionsGroup = None) -> io.NodeOutput: + if not gpu_options: + gpu_options = comfy.multigpu.GPUOptionsGroup() + else: + gpu_options = gpu_options.clone() + + opt = comfy.multigpu.GPUOptions(device_index=device_index, relative_speed=relative_speed) + gpu_options.add(opt) + + return io.NodeOutput(gpu_options) + + +class MultiGPUExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + MultiGPUCFGSplitNode, + SelectModelDeviceNode, + SelectCLIPDeviceNode, + SelectVAEDeviceNode, + # MultiGPUOptionsNode, + ] + + +async def comfy_entrypoint() -> MultiGPUExtension: + return MultiGPUExtension() diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py index e38a33c15..01593b6e6 100644 --- a/comfy_extras/nodes_number_convert.py +++ b/comfy_extras/nodes_number_convert.py @@ -4,7 +4,6 @@ Provides a single node that converts INT, FLOAT, STRING, and BOOL inputs into FLOAT and INT outputs. """ -from __future__ import annotations import math diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py index e104c8480..df7a0b76a 100644 --- a/comfy_extras/nodes_painter.py +++ b/comfy_extras/nodes_painter.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import hashlib import os diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py new file mode 100644 index 000000000..811b9ae8e --- /dev/null +++ b/comfy_extras/nodes_pid.py @@ -0,0 +1,55 @@ +"""PiD (Pixel Diffusion Decoder) node""" + +import torch +from typing_extensions import override + +import node_helpers +import comfy.latent_formats +from comfy_api.latest import ComfyExtension, io + + +class PiDConditioning(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="PiDConditioning", + display_name="PiD Conditioning", + category="advanced/conditioning", + description=( + "Attaches a latent and a degrade_sigma scalar to a CONDITIONING for PiD decoding/upscaling" + ), + inputs=[ + io.Conditioning.Input("positive"), + io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."), + io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux", + tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."), + io.Float.Input( + "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01, + tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.", + ), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput: + samples = latent["samples"] + if latent_format == "flux": + fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux + else: + fmt_cls = comfy.latent_formats.SD3 + lq_latent = fmt_cls().process_in(samples) + sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32) + return io.NodeOutput(node_helpers.conditioning_set_values( + positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t}, + )) + + +class PiDExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [PiDConditioning] + + +async def comfy_entrypoint() -> PiDExtension: + return PiDExtension() diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py index 520b4067e..1628038cc 100644 --- a/comfy_extras/nodes_resolution.py +++ b/comfy_extras/nodes_resolution.py @@ -1,4 +1,3 @@ -from __future__ import annotations import math from enum import Enum from typing_extensions import override diff --git a/comfy_extras/nodes_toolkit.py b/comfy_extras/nodes_toolkit.py index ae802896b..0548a0cf8 100644 --- a/comfy_extras/nodes_toolkit.py +++ b/comfy_extras/nodes_toolkit.py @@ -1,4 +1,3 @@ -from __future__ import annotations from typing_extensions import override from comfy_api.latest import ComfyExtension, io diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 78a2a28f8..ae1d826d5 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import av import torch diff --git a/folder_paths.py b/folder_paths.py index 36d61fcd0..7304e1b73 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import os import time import mimetypes diff --git a/main.py b/main.py index 26d523c30..bce451a83 100644 --- a/main.py +++ b/main.py @@ -218,7 +218,7 @@ import comfy.model_patcher if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") - elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + elif comfy_aimdo.control.init_devices(d.index for d in comfy.model_management.get_all_torch_devices()): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': diff --git a/nodes.py b/nodes.py index 820eeef4c..155ecfe9c 100644 --- a/nodes.py +++ b/nodes.py @@ -1,4 +1,3 @@ -from __future__ import annotations import torch @@ -795,6 +794,7 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): metadata = None + vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -813,6 +813,14 @@ class VAELoader: metadata["tae_latent_channels"] = 128 vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() + # Register a reload factory on the patcher so multigpu deepclones + # (Select VAE Device, future MultiGPU VAE work-units) can produce + # per-device clones from the same loader context. Only set when we + # actually have a single backing file -- pixel_space and the + # image TAESDs (composed from separate encoder/decoder files via + # load_taesd) are not addressable by a single vae_path. + if vae_path is not None: + vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, None)) return (vae,) class ControlNetLoader: @@ -961,7 +969,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -971,7 +979,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b\n pixeldit: gemma 2 2B elm" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -2389,6 +2397,7 @@ async def init_builtin_extra_nodes(): "nodes_lt_audio.py", "nodes_lt.py", "nodes_hooks.py", + "nodes_multigpu.py", "nodes_load_3d.py", "nodes_cosmos.py", "nodes_video.py", @@ -2411,6 +2420,7 @@ async def init_builtin_extra_nodes(): "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", + "nodes_pid.py", "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py", diff --git a/openapi.yaml b/openapi.yaml index 502e518c7..f801a39d9 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -275,7 +275,10 @@ paths: responses: "200": description: Queue updated - + content: + application/json: + schema: + $ref: "#/components/schemas/QueueManageResponse" '400': description: Invalid request parameters content: @@ -3092,18 +3095,34 @@ paths: application/json: schema: type: object - required: - - asset_ids properties: + job_ids: + type: array + items: + type: string + description: Job IDs whose associated assets should all be included in the ZIP bundle. asset_ids: type: array items: type: string format: uuid - description: IDs of assets to export + description: Asset IDs to include in the ZIP bundle. Additive to assets associated with provided job IDs. export_name: type: string description: Name for the export archive + naming_strategy: + type: string + enum: [group_by_job_id, preserve, asset_id, group_by_job_time] + default: group_by_job_time + description: "Strategy for naming files in the ZIP: group by job ID, preserve original names, use the asset ID, or group by job creation time." + job_asset_name_filters: + type: object + additionalProperties: + type: array + minItems: 1 + items: + type: string + description: Optional per-job asset name filters. When provided for a job ID, only assets whose name matches one of the listed names are included. responses: "202": description: Export task accepted @@ -3575,10 +3594,7 @@ paths: content: application/json: schema: - type: array - items: - $ref: "#/components/schemas/HubLabel" - + $ref: "#/components/schemas/HubLabelListResponse" '400': description: Bad request (e.g. invalid type parameter) content: @@ -7466,6 +7482,25 @@ components: type: string description: Array of prompt IDs to delete from queue + QueueManageResponse: + type: object + x-runtime: [cloud] + description: >- + [cloud-only] Result of a queue mutation. The Cloud runtime returns which + items were deleted and whether the queue was cleared; local ComfyUI + returns an empty 200 body. + properties: + deleted: + type: array + nullable: true + items: + type: string + description: Prompt IDs that were deleted from the queue. + cleared: + type: boolean + nullable: true + description: Whether the queue was cleared. + # ------------------------------------------------------------------- # History # ------------------------------------------------------------------- @@ -7546,6 +7581,16 @@ components: outputs_count: type: integer description: Total number of output files + workflow_id: + type: string + nullable: true + x-runtime: [cloud] + description: "[cloud-only] UUID of the Cloud workflow entity this job is associated with. Local ComfyUI returns null." + execution_error: + x-runtime: [cloud] + description: "[cloud-only] Detailed execution error from ComfyUI for failed jobs. Absent on local ComfyUI." + allOf: + - $ref: "#/components/schemas/ExecutionError" JobDetailResponse: type: object @@ -10433,6 +10478,19 @@ components: - custom_node description: Label category. + HubLabelListResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response wrapper for the available Hub label catalog.' + required: + - labels + properties: + labels: + type: array + items: + $ref: '#/components/schemas/HubLabelInfo' + description: Available labels, optionally filtered by type. + HubProfileSummary: type: object x-runtime: [cloud] diff --git a/requirements.txt b/requirements.txt index 2ca6d8929..9308e29d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.82 +comfyui-workflow-templates==0.9.85 comfyui-embedded-docs==0.5.1 torch torchsde diff --git a/server.py b/server.py index 44470b904..268441bd1 100644 --- a/server.py +++ b/server.py @@ -646,18 +646,37 @@ class PromptServer(): @routes.get("/system_stats") async def system_stats(request): - device = comfy.model_management.get_torch_device() - device_name = comfy.model_management.get_torch_device_name(device) + primary_device = comfy.model_management.get_torch_device() cpu_device = comfy.model_management.torch.device("cpu") ram_total = comfy.model_management.get_total_memory(cpu_device) ram_free = comfy.model_management.get_free_memory(cpu_device) - vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) - vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) required_frontend_version = FrontendManager.get_required_frontend_version() installed_templates_version = FrontendManager.get_installed_templates_version() required_templates_version = FrontendManager.get_required_templates_version() comfy_package_versions = FrontendManager.get_comfy_package_versions() + # Report every torch device visible to multigpu, with the primary + # device first so existing clients that read devices[0] keep working. + torch_devices = comfy.model_management.get_all_torch_devices() + if primary_device in torch_devices: + torch_devices = [primary_device] + [d for d in torch_devices if d != primary_device] + else: + torch_devices = [primary_device] + list(torch_devices) + + device_entries = [] + for d in torch_devices: + vram_total, torch_vram_total = comfy.model_management.get_total_memory(d, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(d, torch_free_too=True) + device_entries.append({ + "name": comfy.model_management.get_torch_device_name(d), + "type": d.type, + "index": d.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + }) + system_stats = { "system": { "os": sys.platform, @@ -673,17 +692,7 @@ class PromptServer(): "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", "argv": sys.argv }, - "devices": [ - { - "name": device_name, - "type": device.type, - "index": device.index, - "vram_total": vram_total, - "vram_free": vram_free, - "torch_vram_total": torch_vram_total, - "torch_vram_free": torch_vram_free, - } - ] + "devices": device_entries } return web.json_response(system_stats)