Merge branch 'comfyanonymous:master' into master

patientx, 2024-10-23 00:19:39 +03:00, committed by GitHub
commit fd143ca944
2 changed files with 12 additions and 10 deletions

Changed file 1 of 2:

@@ -5,7 +5,7 @@ from typing import Dict, Optional
 import numpy as np
 import torch
 import torch.nn as nn
-from .. import attention
+from ..attention import optimized_attention
 from einops import rearrange, repeat
 from .util import timestep_embedding
 import comfy.ops
@@ -266,8 +266,6 @@ def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]

-def optimized_attention(qkv, num_heads):
-    return attention.optimized_attention(qkv[0], qkv[1], qkv[2], num_heads)

 class SelfAttention(nn.Module):
     ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug")
@@ -326,9 +324,9 @@ class SelfAttention(nn.Module):
         return x

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        qkv = self.pre_attention(x)
+        q, k, v = self.pre_attention(x)
         x = optimized_attention(
-            qkv, num_heads=self.num_heads
+            q, k, v, heads=self.num_heads
         )
         x = self.post_attention(x)
         return x
@@ -531,8 +529,8 @@ class DismantledBlock(nn.Module):
         assert not self.pre_only
         qkv, intermediates = self.pre_attention(x, c)
         attn = optimized_attention(
-            qkv,
-            num_heads=self.attn.num_heads,
+            qkv[0], qkv[1], qkv[2],
+            heads=self.attn.num_heads,
         )
         return self.post_attention(attn, *intermediates)
@@ -557,8 +555,8 @@ def _block_mixing(context, x, context_block, x_block, c):
     qkv = tuple(o)

     attn = optimized_attention(
-        qkv,
-        num_heads=x_block.attn.num_heads,
+        qkv[0], qkv[1], qkv[2],
+        heads=x_block.attn.num_heads,
     )
     context_attn, x_attn = (
         attn[:, : context_qkv[0].shape[1]],
@@ -642,7 +640,7 @@ class SelfAttentionContext(nn.Module):
     def forward(self, x):
         qkv = self.qkv(x)
         q, k, v = split_qkv(qkv, self.dim_head)
-        x = optimized_attention((q.reshape(q.shape[0], q.shape[1], -1), k, v), self.heads)
+        x = optimized_attention(q.reshape(q.shape[0], q.shape[1], -1), k, v, heads=self.heads)
         return self.proj(x)

 class ContextProcessorBlock(nn.Module):
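
All the hunks in the first file make one change: the module-local optimized_attention(qkv, num_heads) wrapper, which unpacked a packed (q, k, v) tuple, is removed, and every call site now imports optimized_attention from ..attention and passes the three tensors directly with a heads= keyword. Below is a minimal sketch of that call-shape change; the toy optimized_attention here is only a stand-in for comfy.ldm.modules.attention.optimized_attention (which additionally takes a mask and selects an attention backend), not ComfyUI code.

import torch
import torch.nn.functional as F

def optimized_attention(q, k, v, heads, mask=None):
    # q, k, v: (batch, seq_len, heads * dim_head), as at the updated call sites
    b, s, inner = q.shape
    dim_head = inner // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2).reshape(b, s, inner)

batch, seq, heads, dim_head = 2, 16, 4, 8
q = torch.randn(batch, seq, heads * dim_head)
k = torch.randn(batch, seq, heads * dim_head)
v = torch.randn(batch, seq, heads * dim_head)

# Old pattern (removed wrapper): optimized_attention((q, k, v), num_heads=heads)
# New pattern (this diff): pass the tensors directly with a heads= keyword.
x = optimized_attention(q, k, v, heads=heads)
print(x.shape)  # torch.Size([2, 16, 32])

Dropping the wrapper keeps a single attention entry point instead of shadowing the imported name with a tuple-taking local helper.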

Changed file 2 of 2:

@@ -264,10 +264,14 @@ def fp8_linear(self, input):
         scale_input = self.scale_input
         if scale_weight is None:
             scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
+        else:
+            scale_weight = scale_weight.to(input.device)
+
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             inn = input.reshape(-1, input.shape[2]).to(dtype)
         else:
+            scale_input = scale_input.to(input.device)
             inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)

         if bias is not None:
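
The second file's hunk makes fp8_linear move persisted scale tensors onto input.device before they are used; previously the scales were only created on the correct device when they were None, so a scale_weight or scale_input loaded elsewhere (for example on CPU) could reach the fp8 matmul unmoved. A small sketch of the device-alignment pattern under that assumption; align_scales is an illustrative name, not a function in the codebase.

import torch

def align_scales(input, scale_weight=None, scale_input=None):
    """Return per-tensor scales that live on input.device, defaulting to 1.0."""
    if scale_weight is None:
        scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
    else:
        # The diff adds this .to(input.device) so a scale stored on another
        # device no longer reaches the matmul unmoved.
        scale_weight = scale_weight.to(input.device)

    if scale_input is None:
        scale_input = torch.ones((), device=input.device, dtype=torch.float32)
    else:
        scale_input = scale_input.to(input.device)
    return scale_weight, scale_input

x = torch.randn(2, 3, 8)
sw, si = align_scales(x, scale_weight=torch.tensor(0.5))
print(sw.device == x.device, si.item())  # True 1.0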