From f91078b1ffa484c424f78814f54de4d5846e4daa Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:05:26 +0200 Subject: [PATCH 1/6] add PR template for API-Nodes (#10736) --- .github/PULL_REQUEST_TEMPLATE/api-node.md | 21 ++++++++ .github/workflows/api-node-template.yml | 58 +++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/api-node.md create mode 100644 .github/workflows/api-node-template.yml diff --git a/.github/PULL_REQUEST_TEMPLATE/api-node.md b/.github/PULL_REQUEST_TEMPLATE/api-node.md new file mode 100644 index 000000000..f62744878 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/api-node.md @@ -0,0 +1,21 @@ + + +## API Node PR Checklist + +### Scope +- [ ] **Is API Node Change** + +### Pricing & Billing +- [ ] **Need pricing update** +- [ ] **No pricing update** + +If **Need pricing update**: +- [ ] Metronome rate cards updated +- [ ] Auto‑billing tests updated and passing + +### QA +- [ ] **QA done** +- [ ] **QA not required** + +### Comms +- [ ] Informed **@Kosinkadink** diff --git a/.github/workflows/api-node-template.yml b/.github/workflows/api-node-template.yml new file mode 100644 index 000000000..0775f9979 --- /dev/null +++ b/.github/workflows/api-node-template.yml @@ -0,0 +1,58 @@ +name: Append API Node PR template + +on: + pull_request_target: + types: [opened, reopened, synchronize, edited, ready_for_review] + paths: + - 'comfy_api_nodes/**' # only run if these files changed + +permissions: + contents: read + pull-requests: write + +jobs: + inject: + runs-on: ubuntu-latest + steps: + - name: Ensure template exists and append to PR body + uses: actions/github-script@v7 + with: + script: | + const { owner, repo } = context.repo; + const number = context.payload.pull_request.number; + const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md'; + const marker = ''; + + const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number }); + + let templateText; + try { + const res = await github.rest.repos.getContent({ + owner, + repo, + path: templatePath, + ref: pr.base.ref + }); + const buf = Buffer.from(res.data.content, res.data.encoding || 'base64'); + templateText = buf.toString('utf8'); + } catch (e) { + core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`); + return; + } + + // Enforce the presence of the marker inside the template (for idempotence) + if (!templateText.includes(marker)) { + core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`); + return; + } + + // If the PR already contains the marker, do not append again. + const body = pr.body || ''; + if (body.includes(marker)) { + core.info('Template already present in PR body; nothing to inject.'); + return; + } + + const newBody = (body ? body + '\n\n' : '') + templateText + '\n'; + await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody }); + core.notice('API Node template appended to PR description.'); From 2fde9597f4b02c5f06c1a5ceb3ca2fa6d74966ec Mon Sep 17 00:00:00 2001 From: ric-yu Date: Thu, 13 Nov 2025 15:11:52 -0800 Subject: [PATCH 2/6] feat: add create_time dict to prompt field in /history and /queue (#10741) --- server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server.py b/server.py index 5d773b10a..d059d3dc9 100644 --- a/server.py +++ b/server.py @@ -2,6 +2,7 @@ import os import sys import asyncio import traceback +import time import nodes import folder_paths @@ -733,6 +734,7 @@ class PromptServer(): for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS: if sensitive_val in extra_data: sensitive[sensitive_val] = extra_data.pop(sensitive_val) + extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) From 94c298f9625b0fd9af8ea07a73075fdefe0d9e57 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 14 Nov 2025 10:02:03 +1000 Subject: [PATCH 3/6] flux: reduce VRAM usage (#10737) Cleanup a bunch of stack tensors on Flux. This take me from B=19 to B=22 for 1600x1600 on RTX5090. --- comfy/ldm/flux/layers.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index a3eab0470..f4bf56e01 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -167,39 +167,55 @@ class DoubleStreamBlock(nn.Module): img_modulated = self.img_norm1(img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) img_qkv = self.img_attn.qkv(img_modulated) + del img_modulated img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del img_qkv img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt) txt_qkv = self.txt_attn.qkv(txt_modulated) + del txt_modulated txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) if self.flipped_img_txt: + q = torch.cat((img_q, txt_q), dim=2) + del img_q, txt_q + k = torch.cat((img_k, txt_k), dim=2) + del img_k, txt_k + v = torch.cat((img_v, txt_v), dim=2) + del img_v, txt_v # run actual attention - attn = attention(torch.cat((img_q, txt_q), dim=2), - torch.cat((img_k, txt_k), dim=2), - torch.cat((img_v, txt_v), dim=2), + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + del img_attn img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) # calculate the txt bloks txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) + del txt_attn txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt) if txt.dtype == torch.float16: @@ -249,12 +265,15 @@ class SingleStreamBlock(nn.Module): qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del qkv q, k = self.norm(q, k, v) # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + mlp = self.mlp_act(mlp) + output = self.linear2(torch.cat((attn, mlp), 2)) x += apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) From 1ef328c007a419c2c429df0f80532cc11579dc97 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 13 Nov 2025 18:32:39 -0800 Subject: [PATCH 4/6] Better instructions for the portable. (#10743) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f51807ad5..cd8273b0d 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,7 @@ There is a portable standalone build for Windows that should work for running on ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) -Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints +Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\ If you have trouble extracting it, right click the file -> properties -> unblock From f60923590c3f2fd05e166e2ec57968aaf7007dd0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 13 Nov 2025 22:28:05 -0800 Subject: [PATCH 5/6] Use same code for chroma and flux blocks so that optimizations are shared. (#10746) --- comfy/ldm/chroma/layers.py | 121 ----------------------------- comfy/ldm/chroma/model.py | 7 +- comfy/ldm/chroma_radiance/model.py | 7 +- comfy/ldm/flux/layers.py | 31 ++++++-- 4 files changed, 31 insertions(+), 135 deletions(-) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index fc7110cce..9f4ad5bd2 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -1,12 +1,9 @@ import torch from torch import Tensor, nn -from comfy.ldm.flux.math import attention from comfy.ldm.flux.layers import ( MLPEmbedder, RMSNorm, - QKNorm, - SelfAttention, ModulationOut, ) @@ -48,124 +45,6 @@ class Approximator(nn.Module): return x -class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): - super().__init__() - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.num_heads = num_heads - self.hidden_size = hidden_size - self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - - self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - self.flipped_img_txt = flipped_img_txt - - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}): - (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec - - # prepare image for attention - img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img)) - img_qkv = self.img_attn.qkv(img_modulated) - img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) - - # prepare txt for attention - txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt)) - txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - - # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask, transformer_options=transformer_options) - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - - # calculate the img bloks - img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn)) - img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img)))) - - # calculate the txt bloks - txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn)) - txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt)))) - - if txt.dtype == torch.float16: - txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) - - return img, txt - - -class SingleStreamBlock(nn.Module): - """ - A DiT block with parallel linear layers as described in - https://arxiv.org/abs/2302.05442 and adapted modulation interface. - """ - - def __init__( - self, - hidden_size: int, - num_heads: int, - mlp_ratio: float = 4.0, - qk_scale: float = None, - dtype=None, - device=None, - operations=None - ): - super().__init__() - self.hidden_dim = hidden_size - self.num_heads = num_heads - head_dim = hidden_size // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.mlp_hidden_dim = int(hidden_size * mlp_ratio) - # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) - # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) - - self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - - self.hidden_size = hidden_size - self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - - self.mlp_act = nn.GELU(approximate="tanh") - - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor: - mod = vec - x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - - q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k = self.norm(q, k, v) - - # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) - # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x.addcmul_(mod.gate, output) - if x.dtype == torch.float16: - x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) - return x - - class LastLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index ad1c523fe..67bf70eb1 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -11,12 +11,12 @@ import comfy.ldm.common_dit from comfy.ldm.flux.layers import ( EmbedND, timestep_embedding, + DoubleStreamBlock, + SingleStreamBlock, ) from .layers import ( - DoubleStreamBlock, LastLayer, - SingleStreamBlock, Approximator, ChromaModulationOut, ) @@ -90,6 +90,7 @@ class Chroma(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -98,7 +99,7 @@ class Chroma(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 7d7be80f5..e643b4414 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -10,12 +10,10 @@ from torch import Tensor, nn from einops import repeat import comfy.ldm.common_dit -from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock from comfy.ldm.chroma.model import Chroma, ChromaParams from comfy.ldm.chroma.layers import ( - DoubleStreamBlock, - SingleStreamBlock, Approximator, ) from .layers import ( @@ -89,7 +87,6 @@ class ChromaRadiance(Chroma): dtype=dtype, device=device, operations=operations ) - self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( @@ -97,6 +94,7 @@ class ChromaRadiance(Chroma): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -109,6 +107,7 @@ class ChromaRadiance(Chroma): self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, + modulation=False, dtype=dtype, device=device, operations=operations, ) for _ in range(params.depth_single_blocks) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index f4bf56e01..23150a712 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.modulation = modulation + + if self.modulation: + self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) @@ -147,7 +151,9 @@ class DoubleStreamBlock(nn.Module): operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) - self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + if self.modulation: + self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) @@ -160,8 +166,11 @@ class DoubleStreamBlock(nn.Module): self.flipped_img_txt = flipped_img_txt def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): - img_mod1, img_mod2 = self.img_mod(vec) - txt_mod1, txt_mod2 = self.txt_mod(vec) + if self.modulation: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + else: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention img_modulated = self.img_norm1(img) @@ -236,6 +245,7 @@ class SingleStreamBlock(nn.Module): num_heads: int, mlp_ratio: float = 4.0, qk_scale: float = None, + modulation=True, dtype=None, device=None, operations=None @@ -258,10 +268,17 @@ class SingleStreamBlock(nn.Module): self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + if modulation: + self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + else: + self.modulation = None def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor: - mod, _ = self.modulation(vec) + if self.modulation: + mod, _ = self.modulation(vec) + else: + mod = vec + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) From 443056c401c53953bb8eee6da71b9ad29afe2581 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 14 Nov 2025 00:26:05 -0800 Subject: [PATCH 6/6] Fix custom nodes import error. (#10747) This should fix the import errors but will break if the custom nodes actually try to use the class. --- comfy/ldm/chroma/layers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 9f4ad5bd2..2d5684348 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -7,6 +7,9 @@ from comfy.ldm.flux.layers import ( ModulationOut, ) +# TODO: remove this in a few months +SingleStreamBlock = None +DoubleStreamBlock = None class ChromaModulationOut(ModulationOut):