mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 00:37:53 +08:00
Fix MPS upscale/VAE tensor layout and attention chunking issues
This commit is contained in:
parent
6648ab68bc
commit
151e2f274f
@@ -259,7 +259,10 @@ def slice_attention(q, k, v):
|
||||
|
||||
while True:
|
||||
try:
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
slice_size = max(1, math.ceil(q.shape[1] / steps))
|
||||
if q.device.type == "mps":
|
||||
max_slice_for_mps = max(1, ((2 ** 31) - 1) // max(1, q.shape[0] * k.shape[2]))
|
||||
slice_size = min(slice_size, max_slice_for_mps)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||
@@ -270,6 +273,15 @@ def slice_attention(q, k, v):
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except RuntimeError as e:
|
||||
if q.device.type == "mps" and ("INT_MAX" in str(e) or "MPSGraph" in str(e) or "MPSGaph" in str(e)):
|
||||
model_management.soft_empty_cache(True)
|
||||
steps *= 2
|
||||
if steps > 4096:
|
||||
raise e
|
||||
logging.warning("MPS attention limit reached, increasing steps and trying again {}".format(steps))
|
||||
continue
|
||||
raise e
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
model_management.soft_empty_cache(True)
|
||||
steps *= 2
|
||||
|
||||
@@ -1105,7 +1105,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
|
||||
# handle entire input fitting in a single tile
|
||||
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||
output[b:b+1] = function(s).to(output_device)
|
||||
output[b:b+1] = function(s.contiguous()).to(output_device)
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
@@ -1125,7 +1125,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
ps = function(s_in).to(output_device)
|
||||
ps = function(s_in.contiguous()).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user