diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 63cc150d4..f9ee79199 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -7,6 +7,16 @@ from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.math import apply_rope1
 from comfy.ldm.flux.layers import EmbedND

+def attention(q, k, v, heads, transformer_options={}):
+    return optimized_attention(
+        q.transpose(1, 2),
+        k.transpose(1, 2),
+        v.transpose(1, 2),
+        heads=heads,
+        skip_reshape=True,
+        transformer_options=transformer_options
+    )
+
 def apply_scale_shift_norm(norm, x, scale, shift):
     return torch.addcmul(shift, norm(x), scale + 1.0)

@@ -23,6 +33,7 @@ def get_freqs(dim, max_period=10000.0):
         * torch.arange(start=0, end=dim, dtype=torch.float32)
         / dim)

+
 class TimeEmbeddings(nn.Module):
     def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
         super().__init__()
@@ -31,13 +42,13 @@ class TimeEmbeddings(nn.Module):
         self.max_period = max_period
         self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
         self.activation = nn.SiLU()
-        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)

     def forward(self, timestep):
         args = torch.outer(timestep, self.freqs.to(device=timestep.device))
-        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(torch.bfloat16) #todo dtype
+        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed

@@ -81,17 +92,18 @@ class Modulation(nn.Module):
     def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
         super().__init__()
         self.activation = nn.SiLU()
-        operations = operation_settings.get("operations")
-        self.out_layer = operations.Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=torch.float32)

     def forward(self, x):
-        return self.out_layer(self.activation(x))
+        return self.out_layer(self.activation(x.float())).to(x.dtype)
+

 class SelfAttention(nn.Module):
     def __init__(self, num_channels, head_dim, operation_settings=None):
         super().__init__()
         assert num_channels % head_dim == 0
         self.num_heads = num_channels // head_dim
+        self.head_dim = head_dim

         operations = operation_settings.get("operations")
         self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -103,73 +115,29 @@ class SelfAttention(nn.Module):
         self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

     def get_qkv(self, x):
-        q = self.to_query(x)
-        k = self.to_key(x)
-        v = self.to_value(x)
-
-        shape = q.shape[:-1]
-        q = q.view(*shape, self.num_heads, -1)
-        k = k.view(*shape, self.num_heads, -1)
-        v = v.view(*shape, self.num_heads, -1)
-
+        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+        k = self.to_key(x).view(*x.shape[:-1], self.num_heads, -1)
+        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
         return q, k, v

     def forward(self, x, freqs, transformer_options={}):
         q, k, v = self.get_qkv(x)
-
         q = apply_rope1(self.query_norm(q), freqs)
         k = apply_rope1(self.key_norm(k), freqs)
-
-        out = optimized_attention(
-            q.flatten(-2, -1),
-            k.flatten(-2, -1),
-            v.flatten(-2, -1),
-            heads=self.num_heads,
-            transformer_options=transformer_options
-        )
+        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)


-class CrossAttention(nn.Module):
-    def __init__(self, num_channels, head_dim, operation_settings=None):
-        super().__init__()
-        assert num_channels % head_dim == 0
-        self.num_heads = num_channels // head_dim
-
-        operations = operation_settings.get("operations")
-        self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-
-        self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-
-    def get_qkv(self, x, cond):
-        q = self.to_query(x)
-        k = self.to_key(cond)
-        v = self.to_value(cond)
-
-        shape, cond_shape = q.shape[:-1], k.shape[:-1]
-        q = q.view(*shape, self.num_heads, -1)
-        k = k.view(*cond_shape, self.num_heads, -1)
-        v = v.view(*cond_shape, self.num_heads, -1)
-
+class CrossAttention(SelfAttention):
+    def get_qkv(self, x, context):
+        q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
+        k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
+        v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
         return q, k, v

-    def forward(self, x, cond, transformer_options={}):
-        q, k, v = self.get_qkv(x, cond)
-        q = self.query_norm(q)
-        k = self.key_norm(k)
-
-        out = optimized_attention(
-            q.flatten(-2, -1),
-            k.flatten(-2, -1),
-            v.flatten(-2, -1),
-            heads=self.num_heads,
-            transformer_options=transformer_options
-        )
-
+    def forward(self, x, context, transformer_options={}):
+        q, k, v = self.get_qkv(x, context)
+        out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
         return self.out_layer(out)


@@ -210,6 +178,7 @@ class OutLayer(nn.Module):
         )
         return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)

+
 class TransformerEncoderBlock(nn.Module):
     def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
         super().__init__()
@@ -371,7 +340,7 @@ class Kandinsky5(nn.Module):

     def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
         context = self.text_embeddings(context)
-        time_embed = self.time_embeddings(timestep) + self.pooled_text_embeddings(y)
+        time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y)

         for block in self.text_transformer_blocks:
             context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
@@ -392,7 +361,6 @@ class Kandinsky5(nn.Module):
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)

-
     def forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
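Note on the new attention() helper: get_qkv() returns q, k, v as [batch, seq_len, num_heads, head_dim], the helper transposes them to [batch, num_heads, seq_len, head_dim] before calling optimized_attention(..., skip_reshape=True), and out_layer still receives the flattened [batch, seq_len, num_heads * head_dim] output. A minimal sketch of that shape contract, using torch.nn.functional.scaled_dot_product_attention as a stand-in for the optimized kernel (the stand-in and the attention_reference name are illustrative assumptions, not ComfyUI API):

import torch
import torch.nn.functional as F

def attention_reference(q, k, v, heads):
    # q, k, v: [batch, seq_len, heads, head_dim] -- the layout produced by get_qkv().
    # Transpose so heads precede the sequence axis, mirroring what attention() feeds
    # to optimized_attention(..., skip_reshape=True).
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))   # [batch, heads, seq_len, head_dim]
    out = F.scaled_dot_product_attention(q, k, v)      # stand-in for the optimized kernel
    # Fold heads back into channels: [batch, seq_len, heads * head_dim], as out_layer expects.
    return out.transpose(1, 2).reshape(out.shape[0], out.shape[2], heads * out.shape[-1])

b, seq_len, heads, head_dim = 1, 16, 8, 64
q, k, v = (torch.randn(b, seq_len, heads, head_dim) for _ in range(3))
print(attention_reference(q, k, v, heads).shape)  # torch.Size([1, 16, 512])

Compared with the previous q.flatten(-2, -1) path, this hands the per-head layout to the attention kernel directly instead of flattening the head dimension and having the kernel split it again.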