Code cleanup, optimizations, use fp32 for all layers originally at fp32
parent 0920cdcf63
commit 8f02217f85
@@ -7,6 +7,16 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.math import apply_rope1
 from comfy.ldm.flux.layers import EmbedND
 
+def attention(q, k, v, heads, transformer_options={}):
+    return optimized_attention(
+        q.transpose(1, 2),
+        k.transpose(1, 2),
+        v.transpose(1, 2),
+        heads=heads,
+        skip_reshape=True,
+        transformer_options=transformer_options
+    )
+
 def apply_scale_shift_norm(norm, x, scale, shift):
     return torch.addcmul(shift, norm(x), scale + 1.0)
 
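The new attention helper takes q/k/v laid out as (batch, seq, heads, head_dim); transpose(1, 2) moves heads ahead of the sequence axis and skip_reshape=True tells optimized_attention not to re-split heads from a flat channel axis. apply_scale_shift_norm fuses the usual AdaLN-style modulation into one torch.addcmul call. A minimal sketch of that arithmetic (the reference helper name below is illustrative, not from the diff):

import torch
import torch.nn as nn

# torch.addcmul(shift, norm(x), scale + 1.0) evaluates shift + norm(x) * (scale + 1.0)
# in a single fused call; this reference helper spells the same thing out.
def apply_scale_shift_norm_reference(norm, x, scale, shift):
    return norm(x) * (scale + 1.0) + shift

norm = nn.LayerNorm(8, elementwise_affine=False)
x = torch.randn(2, 4, 8)
scale = torch.randn(2, 1, 8)
shift = torch.randn(2, 1, 8)

fused = torch.addcmul(shift, norm(x), scale + 1.0)
assert torch.allclose(fused, apply_scale_shift_norm_reference(norm, x, scale, shift), atol=1e-6)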
@@ -23,6 +33,7 @@ def get_freqs(dim, max_period=10000.0):
         * torch.arange(start=0, end=dim, dtype=torch.float32)
         / dim)
 
+
 class TimeEmbeddings(nn.Module):
     def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
         super().__init__()
@@ -31,13 +42,13 @@ class TimeEmbeddings(nn.Module):
         self.max_period = max_period
         self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
         self.activation = nn.SiLU()
-        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
 
     def forward(self, timestep):
         args = torch.outer(timestep, self.freqs.to(device=timestep.device))
-        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(torch.bfloat16) #todo dtype
+        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
 
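With in_layer and out_layer now created at torch.float32 and the .to(torch.bfloat16) cast removed, the sinusoidal time embedding is computed and projected entirely in full precision; the caller casts the result back to the model dtype (see the .to(x.dtype) added to forward_orig further down). A small sketch of the embedding itself, assuming get_freqs builds the usual exp(-log(max_period) * i / dim) ladder suggested by the arange/dim lines above:

import torch

model_dim, max_period = 8, 10000.0
half = model_dim // 2
# frequency ladder assumed to match the registered get_freqs(model_dim // 2, max_period) buffer
freqs = torch.exp(-torch.log(torch.tensor(max_period))
                  * torch.arange(start=0, end=half, dtype=torch.float32) / half)

timestep = torch.tensor([0.0, 250.0, 999.0])
args = torch.outer(timestep, freqs)                                  # (batch, model_dim // 2)
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # (batch, model_dim), stays fp32
print(time_embed.shape, time_embed.dtype)  # torch.Size([3, 8]) torch.float32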
@@ -81,17 +92,18 @@ class Modulation(nn.Module):
     def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
         super().__init__()
         self.activation = nn.SiLU()
-        operations = operation_settings.get("operations")
-        self.out_layer = operations.Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=torch.float32)
 
     def forward(self, x):
-        return self.out_layer(self.activation(x))
+        return self.out_layer(self.activation(x.float())).to(x.dtype)
 
+
 class SelfAttention(nn.Module):
     def __init__(self, num_channels, head_dim, operation_settings=None):
         super().__init__()
         assert num_channels % head_dim == 0
         self.num_heads = num_channels // head_dim
+        self.head_dim = head_dim
 
         operations = operation_settings.get("operations")
         self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
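Modulation follows the same recipe: fp32 weights, an explicit upcast of the incoming time embedding before SiLU and the projection, and a cast back to the activation dtype at the end. The dtype dance in isolation, with plain nn.Linear standing in for operations.Linear:

import torch
import torch.nn as nn

time_dim, model_dim, num_params = 16, 32, 6
out_layer = nn.Linear(time_dim, num_params * model_dim)  # fp32 weights, like dtype=torch.float32 above
activation = nn.SiLU()

x = torch.randn(2, time_dim, dtype=torch.bfloat16)       # time embedding in the model dtype
# upcast -> compute in fp32 -> cast back, mirroring
# self.out_layer(self.activation(x.float())).to(x.dtype)
y = out_layer(activation(x.float())).to(x.dtype)
print(y.dtype, y.shape)  # torch.bfloat16 torch.Size([2, 192])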
@@ -103,73 +115,29 @@ class SelfAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
     def get_qkv(self, x):
-        q = self.to_query(x)
-        k = self.to_key(x)
-        v = self.to_value(x)
-
-        shape = q.shape[:-1]
-        q = q.view(*shape, self.num_heads, -1)
-        k = k.view(*shape, self.num_heads, -1)
-        v = v.view(*shape, self.num_heads, -1)
-
+        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+        k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         return q, k, v
 
     def forward(self, x, freqs, transformer_options={}):
         q, k, v = self.get_qkv(x)
-
         q = apply_rope1(self.query_norm(q), freqs)
         k = apply_rope1(self.key_norm(k), freqs)
-
-        out = optimized_attention(
-            q.flatten(-2, -1),
-            k.flatten(-2, -1),
-            v.flatten(-2, -1),
-            heads=self.num_heads,
-            transformer_options=transformer_options
-        )
+        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
 
 
-class CrossAttention(nn.Module):
-    def __init__(self, num_channels, head_dim, operation_settings=None):
-        super().__init__()
-        assert num_channels % head_dim == 0
-        self.num_heads = num_channels // head_dim
-
-        operations = operation_settings.get("operations")
-        self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-
-        self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-
-    def get_qkv(self, x, cond):
-        q = self.to_query(x)
-        k = self.to_key(cond)
-        v = self.to_value(cond)
-
-        shape, cond_shape = q.shape[:-1], k.shape[:-1]
-        q = q.view(*shape, self.num_heads, -1)
-        k = k.view(*cond_shape, self.num_heads, -1)
-        v = v.view(*cond_shape, self.num_heads, -1)
-
+class CrossAttention(SelfAttention):
+    def get_qkv(self, x, context):
+        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+        k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
+        v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
         return q, k, v
 
-    def forward(self, x, cond, transformer_options={}):
-        q, k, v = self.get_qkv(x, cond)
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-
-        out = optimized_attention(
-            q.flatten(-2, -1),
-            k.flatten(-2, -1),
-            v.flatten(-2, -1),
-            heads=self.num_heads,
-            transformer_options=transformer_options
-        )
-
+    def forward(self, x, context, transformer_options={}):
+        q, k, v = self.get_qkv(x, context)
+        out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
 
 
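Functionally the attention refactor is a no-op: the old code flattened the per-head q/k/v back to (batch, seq, heads*head_dim) and let optimized_attention re-split the heads, while the new shared attention helper hands over the already-split (batch, heads, seq, head_dim) layout with skip_reshape=True, and CrossAttention now reuses SelfAttention's projections instead of duplicating them. A sketch of why the two layouts give the same result, using scaled_dot_product_attention as a stand-in (the skip_reshape layout is assumed from the call sites above, not spelled out in this diff):

import torch
import torch.nn.functional as F

B, L, H, D = 2, 5, 4, 8
q = torch.randn(B, L, H, D)
k = torch.randn(B, L, H, D)
v = torch.randn(B, L, H, D)

# new path: keep heads separate and move them in front of the sequence axis
out_new = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
)  # (B, H, L, D)
out_new = out_new.transpose(1, 2).reshape(B, L, H * D)

# old path: flatten heads into the channel axis, then re-split before attention
q2 = q.flatten(-2, -1).view(B, L, H, D).transpose(1, 2)
k2 = k.flatten(-2, -1).view(B, L, H, D).transpose(1, 2)
v2 = v.flatten(-2, -1).view(B, L, H, D).transpose(1, 2)
out_old = F.scaled_dot_product_attention(q2, k2, v2).transpose(1, 2).reshape(B, L, H * D)

assert torch.allclose(out_new, out_old, atol=1e-6)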
@@ -210,6 +178,7 @@ class OutLayer(nn.Module):
         )
         return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
 
+
 class TransformerEncoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
         super().__init__()
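For context, the permute/flatten chain in OutLayer (unchanged here) is a standard unpatchify step. Assuming the preceding view arranges x as (batch, T, H, W, channels, pt, ph, pw), which is what the permute order suggests but is not shown in this hunk, the shape flow looks like this:

import torch

B, T, H, W, C, pt, ph, pw = 1, 3, 4, 5, 16, 1, 2, 2
# assumed layout of x right before the permute; the actual view lives outside
# this hunk, so treat this purely as an illustration of the shape flow
x = torch.randn(B, T, H, W, C, pt, ph, pw)

y = x.permute(0, 4, 1, 5, 2, 6, 3, 7)   # (B, C, T, pt, H, ph, W, pw)
y = y.flatten(2, 3)                     # (B, C, T*pt, H, ph, W, pw)
y = y.flatten(3, 4)                     # (B, C, T*pt, H*ph, W, pw)
y = y.flatten(4, 5)                     # (B, C, T*pt, H*ph, W*pw)
print(y.shape)  # torch.Size([1, 16, 3, 8, 10])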
@@ -371,7 +340,7 @@ class Kandinsky5(nn.Module):
 
     def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
         context = self.text_embeddings(context)
-        time_embed = self.time_embeddings(timestep) + self.pooled_text_embeddings(y)
+        time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y)
 
         for block in self.text_transformer_blocks:
             context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
@@ -392,7 +361,6 @@
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
 
-
     def forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,