mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-11 05:52:33 +08:00
bugfix: fix typo in apply_directory for custom_nodes_directory
allow for PATH style ';' delimited custom_node directories.
change delimiter type for separate folders per platform.
feat(API-nodes): move Rodin3D nodes to new client; removed old api client.py (#10645)
Fix qwen controlnet regression. (#10657)
Enable pinned memory by default on Nvidia. (#10656)
Removed the --fast pinned_memory flag.
You can use --disable-pinned-memory to disable it. Please report if it
causes any issues.
Pinned mem also seems to work on AMD. (#10658)
Remove environment variable.
Removed environment variable fallback for custom nodes directory.
Update documentation for custom nodes directory
Clarified documentation on custom nodes directory argument, removed documentation on environment variable
Clarify release cycle. (#10667)
Tell users they need to upload their logs in bug reports. (#10671)
mm: guard against double pin and unpin explicitly (#10672)
As commented, if you let cuda be the one to detect double pin/unpinning
it actually creates an async GPU error.
Only unpin tensor if it was pinned by ComfyUI (#10677)
Make ScaleROPE node work on Flux. (#10686)
Add logging for model unloading. (#10692)
Unload weights if vram usage goes up between runs. (#10690)
ops: Put weight cast on the offload stream (#10697)
This needs to be on the offload stream. This reproduced a black screen
with low resolution images on a slow bus when using FP8.
Update CI workflow to remove dead macOS runner. (#10704)
* Update CI workflow to remove dead macOS runner.
* revert
* revert
Don't pin tensor if not a torch.nn.parameter.Parameter (#10718)
Update README.md for Intel Arc GPU installation, remove IPEX (#10729)
IPEX is no longer needed for Intel Arc GPUs. Removing instruction to setup ipex.
mm/mp: always unload re-used but modified models (#10724)
The partial unloader path in model re-use flow skips straight to the
actual unload without any check of the patching UUID. This means that
if you do an upscale flow with a model patch on an existing model, it
will not apply your patchings.
Fix by delaying the partial_unload until after the uuid checks. This
is done by making partial_unload a model of partial_load where extra_mem
is -ve.
qwen: reduce VRAM usage (#10725)
Clean up a bunch of stacked and no-longer-needed tensors on the QWEN
VRAM peak (currently FFN).
With this I go from OOMing at B=37x1328x1328 to being able to
successfully run B=47 (RTX5090).
Update Python 3.14 compatibility notes in README (#10730)
Quantized Ops fixes (#10715)
* offload support, bug fixes, remove mixins
* add readme
add PR template for API-Nodes (#10736)
feat: add create_time dict to prompt field in /history and /queue (#10741)
flux: reduce VRAM usage (#10737)
Clean up a bunch of stack tensors on Flux. This takes me from B=19 to B=22
for 1600x1600 on RTX5090.
Better instructions for the portable. (#10743)
Use same code for chroma and flux blocks so that optimizations are shared. (#10746)
Fix custom nodes import error. (#10747)
This should fix the import errors but will break if the custom nodes actually try to use the class.
revert import reordering
revert imports pt 2
Add left padding support to tokenizers. (#10753)
chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)
Revert "chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)" (#10759)
This reverts commit 9a02382568.
Change ROCm nightly install command to 7.1 (#10764)
315 lines
13 KiB
Python
315 lines
13 KiB
Python
import math
|
|
from dataclasses import dataclass
|
|
|
|
import torch
|
|
from torch import Tensor, nn
|
|
|
|
from .math import attention, rope
|
|
import comfy.ops
|
|
import comfy.ldm.common_dit
|
|
|
|
|
|
class EmbedND(nn.Module):
    """Positional embedding built by applying ``rope`` to every id axis.

    The constructor only stores its arguments; the module owns no
    parameters or buffers.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.dim, self.theta, self.axes_dim = dim, theta, axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Embed ``ids`` axis by axis and join the per-axis results.

        :param ids: tensor whose last dimension enumerates the axes; axis
                    ``i`` is embedded with ``self.axes_dim[i]`` and
                    ``self.theta``.
        :return: the per-axis ``rope`` outputs concatenated along dim -3,
                 with a singleton dimension inserted at position 1.
        """
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))

        joined = torch.cat(per_axis, dim=-3)
        return joined.unsqueeze(1)
|
|
|
|
|
|
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.

    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2

    # Geometric frequency ladder, always computed in float32 on t's device.
    steps = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
    freqs = torch.exp(-math.log(max_period) * steps / half)

    args = t[:, None].float() * freqs[None]
    pieces = [torch.cos(args), torch.sin(args)]
    if dim % 2:
        # Odd output width: pad with one zero column after the cos/sin halves.
        pieces.append(torch.zeros_like(args[:, :1]))
    embedding = torch.cat(pieces, dim=-1)

    # Match the dtype/device of the (scaled) input when it is floating point.
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
|
|
|
|
class MLPEmbedder(nn.Module):
    """Two-layer MLP: Linear -> SiLU -> Linear.

    :param in_dim: input feature width.
    :param hidden_dim: width of both the hidden and the output features.
    :param dtype: forwarded to both linear layers.
    :param device: forwarded to both linear layers.
    :param operations: namespace providing the ``Linear`` class to use.
    """

    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through in_layer, SiLU and out_layer, in that order."""
        hidden = self.in_layer(x)
        hidden = self.silu(hidden)
        return self.out_layer(hidden)
|
|
|
|
|
|
class RMSNorm(torch.nn.Module):
    """RMS normalisation with a learnable per-feature ``scale``.

    ``scale`` is allocated with ``torch.empty`` and is therefore
    uninitialised until weights are loaded. ``operations`` is accepted for
    signature parity with the other layers but is not used here.
    """

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        uninitialised = torch.empty((dim), dtype=dtype, device=device)
        self.scale = nn.Parameter(uninitialised)

    def forward(self, x: Tensor):
        """Delegate to ``comfy.ldm.common_dit.rms_norm`` with eps=1e-6."""
        eps = 1e-6
        return comfy.ldm.common_dit.rms_norm(x, self.scale, eps)
|
|
|
|
|
|
class QKNorm(torch.nn.Module):
    """Applies independent RMSNorm layers to queries and keys."""

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        shared_kwargs = {"dtype": dtype, "device": device, "operations": operations}
        self.query_norm = RMSNorm(dim, **shared_kwargs)
        self.key_norm = RMSNorm(dim, **shared_kwargs)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        """Normalise ``q`` and ``k`` and convert both with ``.to(v)``.

        ``v`` is used only as the conversion target; it is not modified
        and is not part of the returned tuple.

        :return: ``(q, k)`` after normalisation and conversion.
        """
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
|
|
|
|
|
|
class SelfAttention(nn.Module):
    """Container for the layers of one attention stream.

    Holds a fused QKV projection, a per-head query/key norm and an output
    projection. No ``forward`` is defined in this class; the sub-layers are
    invoked individually by the blocks that own it.

    :param dim: model width; must be divisible by ``num_heads`` for the
                per-head width (``dim // num_heads``) to be exact.
    :param num_heads: number of attention heads.
    :param qkv_bias: whether the fused QKV projection has a bias.
    :param operations: namespace providing the ``Linear`` class to use.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        per_head_width = dim // num_heads

        # One matmul produces q, k and v stacked along the feature axis.
        self.qkv = operations.Linear(dim, 3 * dim, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(per_head_width, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
|
|
|
|
|
@dataclass
class ModulationOut:
    """The three tensors that make up one modulation step."""

    # Additive term passed to ``apply_mod`` alongside ``1 + scale``.
    shift: Tensor
    # Multiplicative term; the blocks in this file use it as ``1 + scale``.
    scale: Tensor
    # Multiplier applied to a branch output before it is added to the residual.
    gate: Tensor
|
|
|
|
|
|
class Modulation(nn.Module):
    """Projects a conditioning vector into shift/scale/gate triples.

    :param dim: width of the conditioning vector and of each output chunk.
    :param double: when true, two triples (six chunks) are produced;
                   otherwise one triple (three chunks).
    :param operations: namespace providing the ``Linear`` class to use.
    """

    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
        super().__init__()
        self.is_double = double
        if double:
            self.multiplier = 6
        else:
            self.multiplier = 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

    def forward(self, vec: Tensor) -> tuple:
        """Return ``(first, second)`` modulation triples for ``vec``.

        A 2-D ``vec`` gets a singleton axis inserted at position 1 first.
        ``second`` is ``None`` unless the module was built with
        ``double=True``.
        """
        if vec.ndim == 2:
            vec = vec[:, None, :]

        activated = nn.functional.silu(vec)
        chunks = self.lin(activated).chunk(self.multiplier, dim=-1)

        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:]) if self.is_double else None
        return (first, second)
|
|
|
|
|
|
def _pick_modulation(mod, index, target_ndim):
    """Select ``mod[:, index]`` and keep it broadcastable over the token axis."""
    picked = mod[:, index]
    if picked.ndim == target_ndim - 1:
        # An integer index dropped the selector axis. Re-insert a singleton
        # axis at position 1 so that dim 0 stays aligned with dim 0 of the
        # tensor being modulated, instead of sliding onto its token axis.
        picked = picked.unsqueeze(1)
    return picked


def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
    """Scale (and optionally shift) ``tensor`` by modulation tensors.

    Without ``modulation_dims`` the result is a new tensor:
    ``tensor * m_mult`` or, when ``m_add`` is given,
    ``m_add + tensor * m_mult``.

    With ``modulation_dims`` the update is done IN PLACE on ``tensor``,
    which is also returned. Each entry ``d`` selects the slice
    ``tensor[:, d[0]:d[1]]`` and modulates it with ``m_mult[:, d[2]]``
    (and ``m_add[:, d[2]]`` when ``m_add`` is given).

    Previously the selected modulation lost its second axis, so for a
    ``(B, n, D)`` slice a ``(B, D)`` modulation was broadcast with its
    batch axis aligned to the token axis. That raised for ``B > 1`` (or
    silently mixed batch entries when ``B == n``). The selected modulation
    is now reshaped to ``(B, 1, D)``; results for ``B == 1`` are unchanged.
    NOTE(review): this assumes dim 0 of ``m_mult``/``m_add`` is the same
    batch axis as dim 0 of ``tensor`` — confirm against callers outside
    this file.

    :param tensor: tensor to modulate; mutated when ``modulation_dims`` is set.
    :param m_mult: multiplicative modulation.
    :param m_add: optional additive modulation.
    :param modulation_dims: optional iterable of ``(start, stop, index)``
                            style entries; ``stop`` may be ``None``.
    :return: the modulated tensor.
    """
    if modulation_dims is None:
        if m_add is not None:
            return torch.addcmul(m_add, tensor, m_mult)
        return tensor * m_mult

    for d in modulation_dims:
        tensor[:, d[0]:d[1]] *= _pick_modulation(m_mult, d[2], tensor.ndim)
        if m_add is not None:
            tensor[:, d[0]:d[1]] += _pick_modulation(m_add, d[2], tensor.ndim)
    return tensor
|
|
|
|
|
|
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams.

    Each stream owns its own norms, fused QKV / output projections and MLP.
    The per-head q, k and v of both streams are concatenated for a single
    ``attention`` call, and the result is split back into the two streams.

    :param hidden_size: model width of both streams.
    :param num_heads: number of attention heads.
    :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
    :param qkv_bias: whether the fused QKV projections have a bias.
    :param flipped_img_txt: when true the image tokens come first in the
                            joint sequence, otherwise the text tokens do.
    :param modulation: when true the block builds its own ``Modulation``
                       layers; when false ``forward`` expects ``vec`` to
                       already hold the modulation outputs.
    :param operations: namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.modulation = modulation

        # --- image stream ---
        if self.modulation:
            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        # --- text stream ---
        if self.modulation:
            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)

        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options=None):
        """Run one joint-attention + MLP step over both streams.

        ``img`` and ``txt`` are updated in place with ``+=``; the returned
        ``img`` is that same tensor. ``txt`` may be replaced by a new tensor
        when it is float16 (see the end of the method).

        :param img: image-stream tokens, indexed as (batch, tokens, width).
        :param txt: text-stream tokens, indexed as (batch, tokens, width).
        :param vec: conditioning input for the modulation layers, or, when
                    the block was built with ``modulation=False``, a pair
                    ``((img_mod1, img_mod2), (txt_mod1, txt_mod2))``.
        :param pe: positional encoding forwarded to ``attention``.
        :param attn_mask: optional mask forwarded to ``attention``.
        :param modulation_dims_img: optional ``apply_mod`` segments for img.
        :param modulation_dims_txt: optional ``apply_mod`` segments for txt.
        :param transformer_options: options dict forwarded to ``attention``;
                                    ``None`` means a fresh empty dict.
        :return: ``(img, txt)``.
        """
        # A literal ``{}`` default would be one dict shared by every call;
        # create a fresh one per call instead.
        if transformer_options is None:
            transformer_options = {}

        if self.modulation:
            img_mod1, img_mod2 = self.img_mod(vec)
            txt_mod1, txt_mod2 = self.txt_mod(vec)
        else:
            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated  # release intermediates early to keep peak memory down
        # (batch, tokens, 3 * width) -> three tensors of (batch, heads, tokens, head_width)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del img_qkv
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del txt_qkv
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        if self.flipped_img_txt:
            # joint sequence order: image tokens, then text tokens
            q = torch.cat((img_q, txt_q), dim=2)
            del img_q, txt_q
            k = torch.cat((img_k, txt_k), dim=2)
            del img_k, txt_k
            v = torch.cat((img_v, txt_v), dim=2)
            del img_v, txt_v
            # run actual attention
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v

            # split the joint result by token count (tokens indexed on dim 1 here)
            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            # joint sequence order: text tokens, then image tokens
            q = torch.cat((txt_q, img_q), dim=2)
            del txt_q, img_q
            k = torch.cat((txt_k, img_k), dim=2)
            del txt_k, img_k
            v = torch.cat((txt_v, img_v), dim=2)
            del txt_v, img_v
            # run actual attention
            attn = attention(q, k, v,
                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
            del q, k, v

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img blocks
        img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
        del img_attn
        img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

        # calculate the txt blocks
        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
        del txt_attn
        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

        # For float16 text activations, replace NaN with 0 and +/-inf with
        # +/-65504 (the float16 limits). Only txt is treated this way.
        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
|
|
|
|
|
|
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    ``linear1`` produces the fused QKV and the MLP input in one matmul;
    ``linear2`` consumes the concatenated attention output and activated
    MLP features.

    :param hidden_size: model width.
    :param num_heads: number of attention heads.
    :param mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
    :param qk_scale: optional scale stored on ``self.scale``; a falsy value
                     (``None`` or ``0``) falls back to ``head_dim ** -0.5``.
                     ``self.scale`` is not read by ``forward`` in this class.
    :param modulation: when true the block builds its own ``Modulation``
                       layer; when false ``forward`` uses ``vec`` directly
                       as the modulation output.
    :param operations: namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        modulation=True,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        if modulation:
            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
        else:
            self.modulation = None

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options=None) -> Tensor:
        """Run one parallel attention + MLP step.

        ``x`` is updated in place with ``+=``. For float16 input the
        returned tensor is a new one produced by ``nan_to_num``.

        :param x: input tokens, indexed as (batch, tokens, width).
        :param vec: conditioning input for the modulation layer, or the
                    modulation output itself when the block was built with
                    ``modulation=False``.
        :param pe: positional encoding forwarded to ``attention``.
        :param attn_mask: optional mask forwarded to ``attention``.
        :param modulation_dims: optional ``apply_mod`` segments.
        :param transformer_options: options dict forwarded to ``attention``;
                                    ``None`` means a fresh empty dict.
        :return: the updated tokens.
        """
        # A literal ``{}`` default would be one dict shared by every call;
        # create a fresh one per call instead.
        if transformer_options is None:
            transformer_options = {}

        if self.modulation is not None:
            mod, _ = self.modulation(vec)
        else:
            mod = vec

        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        # (batch, tokens, 3 * width) -> three tensors of (batch, heads, tokens, head_width)
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        del qkv  # release intermediates early to keep peak memory down
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
        del q, k, v
        # compute activation in mlp stream, cat again and run second linear layer
        mlp = self.mlp_act(mlp)
        output = self.linear2(torch.cat((attn, mlp), 2))
        x += apply_mod(output, mod.gate, None, modulation_dims)
        # For float16 activations, replace NaN with 0 and +/-inf with
        # +/-65504 (the float16 limits).
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
|
|
|
|
|
|
class LastLayer(nn.Module):
    """Final layer: modulated LayerNorm followed by a linear projection.

    :param hidden_size: width of the incoming tokens.
    :param patch_size: the projection emits
                       ``patch_size * patch_size * out_channels`` features.
    :param out_channels: see ``patch_size``.
    :param operations: namespace providing ``Linear`` and ``LayerNorm``.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        out_features = patch_size * patch_size * out_channels
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_features, bias=True, dtype=dtype, device=device)
        # SiLU followed by a projection to 2 * hidden_size (shift and scale).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
        """Normalise ``x``, modulate it with ``vec`` and project it.

        A 2-D ``vec`` gets a singleton axis inserted at position 1 before
        it is turned into ``shift`` and ``scale`` (the two halves of the
        ``adaLN_modulation`` output along the last dimension).
        """
        if vec.ndim == 2:
            vec = vec[:, None, :]

        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
        modulated = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
        return self.linear(modulated)
|