mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-23 04:40:15 +08:00)
Code cleanup, optimizations, use fp32 for all layers originally at fp32
This commit is contained in: parent 0920cdcf63, commit 8f02217f85
@@ -7,6 +7,16 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.math import apply_rope1
 from comfy.ldm.flux.layers import EmbedND
 
+def attention(q, k, v, heads, transformer_options={}):
+    return optimized_attention(
+        q.transpose(1, 2),
+        k.transpose(1, 2),
+        v.transpose(1, 2),
+        heads=heads,
+        skip_reshape=True,
+        transformer_options=transformer_options
+    )
+
 
 def apply_scale_shift_norm(norm, x, scale, shift):
     return torch.addcmul(shift, norm(x), scale + 1.0)
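Note (reviewer sketch, not part of the commit): the new `attention` helper assumes q/k/v arrive as `(batch, seq, heads, head_dim)` — the layout `get_qkv` produces below — transposes them to `(batch, heads, seq, head_dim)`, and passes `skip_reshape=True` so `optimized_attention` does not re-split the heads itself. A minimal illustration with made-up sizes; the import path is an assumption:

```python
import torch
from comfy.ldm.kandinsky5.model import attention  # assumed module path

b, s, heads, head_dim = 2, 16, 8, 64           # made-up sizes
q = torch.randn(b, s, heads, head_dim)         # layout produced by get_qkv
k = torch.randn(b, s, heads, head_dim)
v = torch.randn(b, s, heads, head_dim)

out = attention(q, k, v, heads)                # transposed to (b, heads, s, head_dim) internally
print(out.shape)                               # expected: (b, s, heads * head_dim) == (2, 16, 512)
```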
@@ -23,6 +33,7 @@ def get_freqs(dim, max_period=10000.0):
         * torch.arange(start=0, end=dim, dtype=torch.float32)
         / dim)
 
+
 class TimeEmbeddings(nn.Module):
     def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
         super().__init__()
@@ -31,13 +42,13 @@ class TimeEmbeddings(nn.Module):
         self.max_period = max_period
         self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
         self.activation = nn.SiLU()
-        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
 
     def forward(self, timestep):
         args = torch.outer(timestep, self.freqs.to(device=timestep.device))
-        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(torch.bfloat16) #todo dtype
+        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
 
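Note (reviewer sketch, not part of the commit): `get_freqs` already builds the `freqs` buffer in fp32, so with `in_layer`/`out_layer` now created with `dtype=torch.float32` and the stray `.to(torch.bfloat16)` removed, the whole timestep-embedding path stays in fp32; the cast down to the model dtype happens later, in `forward_orig` (see the last hunks). A rough standalone sketch of that fp32 path, with made-up sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model_dim, time_dim = 64, 256                            # made-up sizes

# fp32 sinusoidal frequencies, in the spirit of get_freqs
half = model_dim // 2
freqs = torch.exp(-torch.log(torch.tensor(10000.0))
                  * torch.arange(half, dtype=torch.float32) / half)

in_layer = nn.Linear(model_dim, time_dim, dtype=torch.float32)   # layer kept in fp32
out_layer = nn.Linear(time_dim, time_dim, dtype=torch.float32)   # layer kept in fp32

timestep = torch.rand(4)                                 # fp32 timesteps
args = torch.outer(timestep, freqs)                      # (4, half), fp32
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # no bf16 cast anymore
time_embed = out_layer(F.silu(in_layer(time_embed)))
print(time_embed.dtype)                                  # torch.float32
```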
@@ -81,17 +92,18 @@ class Modulation(nn.Module):
     def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
         super().__init__()
         self.activation = nn.SiLU()
-        operations = operation_settings.get("operations")
-        self.out_layer = operations.Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=torch.float32)
 
     def forward(self, x):
-        return self.out_layer(self.activation(x))
+        return self.out_layer(self.activation(x.float())).to(x.dtype)
 
 
 class SelfAttention(nn.Module):
     def __init__(self, num_channels, head_dim, operation_settings=None):
         super().__init__()
         assert num_channels % head_dim == 0
         self.num_heads = num_channels // head_dim
+        self.head_dim = head_dim
+
         operations = operation_settings.get("operations")
         self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
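Note (reviewer sketch, not part of the commit): with `out_layer` now created in fp32, `Modulation.forward` upcasts the incoming (possibly bf16/fp16) embedding, runs SiLU and the projection in fp32, and casts the result back to the caller's dtype, so the surrounding blocks are unaffected. The same round-trip in isolation, with made-up sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

time_dim, model_dim, num_params = 256, 64, 6              # made-up sizes
out_layer = nn.Linear(time_dim, num_params * model_dim, dtype=torch.float32)

x = torch.randn(4, time_dim, dtype=torch.bfloat16)         # e.g. a bf16 time embedding
y = out_layer(F.silu(x.float())).to(x.dtype)               # compute in fp32, return in the input dtype
print(y.dtype, tuple(y.shape))                              # torch.bfloat16 (4, 384)
```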
@@ -103,73 +115,29 @@ class SelfAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
     def get_qkv(self, x):
-        q = self.to_query(x)
-        k = self.to_key(x)
-        v = self.to_value(x)
-
-        shape = q.shape[:-1]
-        q = q.view(*shape, self.num_heads, -1)
-        k = k.view(*shape, self.num_heads, -1)
-        v = v.view(*shape, self.num_heads, -1)
-
+        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+        k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         return q, k, v
 
     def forward(self, x, freqs, transformer_options={}):
         q, k, v = self.get_qkv(x)
 
         q = apply_rope1(self.query_norm(q), freqs)
         k = apply_rope1(self.key_norm(k), freqs)
-
-        out = optimized_attention(
-            q.flatten(-2, -1),
-            k.flatten(-2, -1),
-            v.flatten(-2, -1),
-            heads=self.num_heads,
-            transformer_options=transformer_options
-        )
+        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
 
 
-class CrossAttention(nn.Module):
-    def __init__(self, num_channels, head_dim, operation_settings=None):
-        super().__init__()
-        assert num_channels % head_dim == 0
-        self.num_heads = num_channels // head_dim
-
-        operations = operation_settings.get("operations")
-        self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-
-        self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-
-    def get_qkv(self, x, cond):
-        q = self.to_query(x)
-        k = self.to_key(cond)
-        v = self.to_value(cond)
-
-        shape, cond_shape = q.shape[:-1], k.shape[:-1]
-        q = q.view(*shape, self.num_heads, -1)
-        k = k.view(*cond_shape, self.num_heads, -1)
-        v = v.view(*cond_shape, self.num_heads, -1)
-
+class CrossAttention(SelfAttention):
+    def get_qkv(self, x, context):
+        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+        k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
+        v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
         return q, k, v
 
-    def forward(self, x, cond, transformer_options={}):
-        q, k, v = self.get_qkv(x, cond)
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-
-        out = optimized_attention(
-            q.flatten(-2, -1),
-            k.flatten(-2, -1),
-            v.flatten(-2, -1),
-            heads=self.num_heads,
-            transformer_options=transformer_options
-        )
-
+    def forward(self, x, context, transformer_options={}):
+        q, k, v = self.get_qkv(x, context)
+        out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)
 
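Note (reviewer sketch, not part of the commit): `CrossAttention` now inherits everything `SelfAttention.__init__` builds (`to_query`, `to_key`, `to_value`, the RMSNorms and `out_layer`) and only overrides `get_qkv` (keys/values come from the conditioning sequence) and `forward` (no RoPE). Since the attribute names are unchanged, the checkpoint keys should stay the same. A stripped-down sketch of the inheritance pattern with plain `nn.Linear` placeholders:

```python
import torch
import torch.nn as nn

class SelfAttentionSketch(nn.Module):                 # simplified stand-in, not the real class
    def __init__(self, num_channels, head_dim):
        super().__init__()
        self.num_heads = num_channels // head_dim
        self.to_query = nn.Linear(num_channels, num_channels)
        self.to_key = nn.Linear(num_channels, num_channels)
        self.to_value = nn.Linear(num_channels, num_channels)

    def get_qkv(self, x):
        # project, then split the channel dim into (heads, head_dim)
        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
        k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
        return q, k, v

class CrossAttentionSketch(SelfAttentionSketch):      # reuses the parent's projections
    def get_qkv(self, x, context):
        # queries from x, keys/values from the conditioning sequence
        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
        k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
        v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
        return q, k, v

attn = CrossAttentionSketch(num_channels=512, head_dim=64)
q, k, v = attn.get_qkv(torch.randn(2, 16, 512), torch.randn(2, 77, 512))
print(q.shape, k.shape)                               # (2, 16, 8, 64) (2, 77, 8, 64)
```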
@@ -210,6 +178,7 @@ class OutLayer(nn.Module):
         )
         return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
 
+
 class TransformerEncoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
         super().__init__()
@@ -371,7 +340,7 @@ class Kandinsky5(nn.Module):
 
     def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
         context = self.text_embeddings(context)
-        time_embed = self.time_embeddings(timestep) + self.pooled_text_embeddings(y)
+        time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y)
 
         for block in self.text_transformer_blocks:
             context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
@@ -392,7 +361,6 @@ class Kandinsky5(nn.Module):
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
 
-
     def forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,