import math
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from .math import attention, rope
from ..common_dit import rms_norm
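

# EmbedND builds the rotary position embedding ("pe") consumed by the attention
# calls below. Each axis of the last dimension of `ids` gets its own rope()
# table of width axes_dim[i]; the per-axis tables are concatenated and a
# singleton axis is added with unsqueeze(1) so the result can broadcast across
# attention heads.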
class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )

        return emb.unsqueeze(1)
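

# Sinusoidal embedding: after scaling t by time_factor, frequency i is
# max_period ** (-i / half) and the output is [cos(t * f_i) | sin(t * f_i)],
# zero-padded by one column when dim is odd.
# Illustrative call (shapes only, assuming a float input tensor):
#   timestep_embedding(torch.tensor([0.25, 1.0]), 256)  # -> shape (2, 256)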
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
              These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an (N, D) Tensor of positional embeddings.
    """
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding
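

# MLPEmbedder projects a conditioning vector (e.g. the timestep embedding) to
# the model width with a two-layer SiLU MLP. Note that `operations` is expected
# to supply Linear/LayerNorm replacements (ComfyUI typically passes its
# comfy.ops classes here so weights can be cast or left uninitialized); plain
# torch.nn layers accept the same keyword arguments in recent PyTorch releases.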
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.silu = nn.SiLU()
        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))
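

# RMSNorm keeps only a learnable scale and delegates the actual normalization
# to the shared rms_norm() helper with eps=1e-6; there is no bias term.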
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

    def forward(self, x: Tensor):
        return rms_norm(x, self.scale, 1e-6)
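

# QKNorm RMS-normalizes queries and keys over head_dim before attention (a
# common stabilizer for low-precision inference) and casts both back to v's
# dtype so the attention kernel sees matching types.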
class QKNorm(torch.nn.Module):
    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)
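

# SelfAttention is used as a parameter container: it deliberately defines no
# forward(). DoubleStreamBlock calls .qkv/.norm/.proj directly so that image
# and text tokens can be projected separately but attended to jointly.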
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
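

# Modulation implements the adaLN-style conditioning used throughout: the
# conditioning vector is passed through SiLU and a linear layer, then chunked
# into (shift, scale, gate) triples -- one triple for single blocks, two for
# double blocks. Callers apply it as (1 + scale) * norm(x) + shift and gate the
# residual update.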
@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

    def forward(self, vec: Tensor) -> tuple:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)

        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )
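

# DoubleStreamBlock keeps separate weights for the image and text streams but
# runs a single joint attention: per-stream modulated norms produce q/k/v, the
# text and image q/k/v are concatenated along the sequence axis (text first),
# one attention() call mixes both, and the result is split back at
# txt.shape[1]. Each stream then applies its own gated MLP. In float16 the
# text activations are clamped via nan_to_num to stay inside the fp16 range.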
class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k, img_v = torch.unbind(img_qkv, dim=0)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2), pe=pe)

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
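

# SingleStreamBlock fuses the projections: linear1 emits q/k/v and the MLP
# hidden activations in one matmul, attention and the GELU branch run in
# parallel on the same pre-normed input, and linear2 projects their
# concatenation back to hidden_size as a single gated residual update.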
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = torch.unbind(qkv, dim=0)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
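

# LastLayer applies a final adaLN modulation (shift/scale only, no gate) and
# projects every token to patch_size * patch_size * out_channels values,
# typically reshaped back into latent patches by the caller.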
class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x