mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-22 12:20:16 +08:00)
Merge branch 'comfyanonymous:master' into feature/preview-latent

commit ea351c6f67
@@ -146,6 +146,41 @@ class ResnetBlock(nn.Module):
         return x+h
 
 
+def slice_attention(q, k, v):
+    r1 = torch.zeros_like(k, device=q.device)
+    scale = (int(q.shape[-1])**(-0.5))
+
+    mem_free_total = model_management.get_free_memory(q.device)
+
+    gb = 1024 ** 3
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
+    steps = 1
+
+    if mem_required > mem_free_total:
+        steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+    while True:
+        try:
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            for i in range(0, q.shape[1], slice_size):
+                end = i + slice_size
+                s1 = torch.bmm(q[:, i:end], k) * scale
+
+                s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
+                del s1
+
+                r1[:, :, i:end] = torch.bmm(v, s2)
+                del s2
+            break
+        except model_management.OOM_EXCEPTION as e:
+            steps *= 2
+            if steps > 128:
+                raise e
+            print("out of memory error, increasing steps and trying again", steps)
+    return r1
+
 class AttnBlock(nn.Module):
     def __init__(self, in_channels):
         super().__init__()
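A note on the heuristic above: mem_required estimates the size of the full attention matrix (batch * query rows * key columns * element size, scaled by an empirical modifier of 3 for fp16 or 2.5 otherwise), and steps is rounded up to the next power of two so that each slice of query rows fits in free memory. A worked example with made-up numbers (an illustration, not part of the commit):

import math

tensor_size = 4 * 4096 * 4096 * 2      # b * hw_q * hw_k * bytes (fp16), hypothetical sizes
mem_required = tensor_size * 3          # fp16 modifier from the commit
mem_free_total = 64 * 1024 ** 2         # pretend only 64 MiB is free
steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))
print(steps)  # -> 8: the 4096 query rows would be processed in slices of 512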
@@ -183,48 +218,15 @@ class AttnBlock(nn.Module):
 
         # compute attention
         b,c,h,w = q.shape
-        scale = (int(c)**(-0.5))
-
         q = q.reshape(b,c,h*w)
         q = q.permute(0,2,1)   # b,hw,c
         k = k.reshape(b,c,h*w) # b,c,hw
         v = v.reshape(b,c,h*w)
 
-        r1 = torch.zeros_like(k, device=q.device)
-
-        mem_free_total = model_management.get_free_memory(q.device)
-
-        gb = 1024 ** 3
-        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
-        modifier = 3 if q.element_size() == 2 else 2.5
-        mem_required = tensor_size * modifier
-        steps = 1
-
-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
-
-        while True:
-            try:
-                slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-                for i in range(0, q.shape[1], slice_size):
-                    end = i + slice_size
-                    s1 = torch.bmm(q[:, i:end], k) * scale
-
-                    s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
-                    del s1
-
-                    r1[:, :, i:end] = torch.bmm(v, s2)
-                    del s2
-                break
-            except model_management.OOM_EXCEPTION as e:
-                steps *= 2
-                if steps > 128:
-                    raise e
-                print("out of memory error, increasing steps and trying again", steps)
-
+        r1 = slice_attention(q, k, v)
         h_ = r1.reshape(b,c,h,w)
         del r1
 
         h_ = self.proj_out(h_)
 
         return x+h_
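With the reshapes above, the tensors handed to slice_attention have the layouts it assumes: q is (b, hw, c), k and v are (b, c, hw), and the returned r1 is (b, c, hw), so the reshape back to (b, c, h, w) is unchanged. Below is a self-contained sketch of the slicing idea, with the OOM retry and model_management parts omitted; it checks that partitioning the query rows reproduces full attention (shapes are illustrative):

import torch

def sliced_attention_sketch(q, k, v, steps=4):
    # q: (b, hw, c); k, v: (b, c, hw) -- the layouts used by the commit
    r1 = torch.zeros_like(k)
    scale = int(q.shape[-1]) ** -0.5
    slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        s1 = torch.bmm(q[:, i:end], k) * scale                    # (b, slice, hw) scores
        s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
        r1[:, :, i:end] = torch.bmm(v, s2)                        # (b, c, slice) output
    return r1

# Slicing only partitions the query rows, so the result matches
# unsliced attention up to float tolerance:
b, c, hw = 2, 16, 32
q, k, v = torch.randn(b, hw, c), torch.randn(b, c, hw), torch.randn(b, c, hw)
full = torch.bmm(v, torch.nn.functional.softmax(torch.bmm(q, k) * c ** -0.5, dim=2).permute(0, 2, 1))
assert torch.allclose(sliced_attention_sketch(q, k, v), full, atol=1e-5)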
@@ -331,25 +333,18 @@ class MemoryEfficientAttnBlockPytorch(nn.Module):
 
         # compute attention
         B, C, H, W = q.shape
-        q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
-
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
-            .reshape(B, t.shape[1], 1, C)
-            .permute(0, 2, 1, 3)
-            .reshape(B * 1, t.shape[1], C)
-            .contiguous(),
+            lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
             (q, k, v),
         )
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
 
-        out = (
-            out.unsqueeze(0)
-            .reshape(B, 1, out.shape[1], C)
-            .permute(0, 2, 1, 3)
-            .reshape(B, out.shape[1], C)
-        )
-        out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+        try:
+            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+            out = out.transpose(2, 3).reshape(B, C, H, W)
+        except model_management.OOM_EXCEPTION as e:
+            print("scaled_dot_product_attention OOMed: switched to slice attention")
+            out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
+
         out = self.proj_out(out)
         return x+out
 
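A minimal sketch of the shape handling in this hunk (dummy sizes, happy path only): the (B, C, H, W) feature maps are packed as (B, 1, HW, C) single-head tensors for scaled_dot_product_attention and unpacked afterwards. The except branch reuses the already-packed tensors, since q.view(B, -1, C) yields the (B, HW, C) layout and k.view(B, -1, C).transpose(1, 2) the (B, C, HW) layout that slice_attention expects.

import torch

B, C, H, W = 1, 8, 4, 4
x = torch.randn(B, C, H, W)
q = k = v = x

# pack: (B, C, H, W) -> (B, 1, HW, C), i.e. a single attention head
q, k, v = (t.view(B, 1, C, -1).transpose(2, 3).contiguous() for t in (q, k, v))
print(q.shape)  # torch.Size([1, 1, 16, 8])

out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
out = out.transpose(2, 3).reshape(B, C, H, W)  # unpack to a feature map
print(out.shape)  # torch.Size([1, 8, 4, 4])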