Fix MPS upscale/VAE tensor layout and attention chunking issues

This commit is contained in:
fritzprix 2026-02-14 14:38:30 +09:00
parent 6648ab68bc
commit 151e2f274f
2 changed files with 15 additions and 3 deletions

View File

@@ -259,7 +259,10 @@ def slice_attention(q, k, v):
while True:
try:
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
slice_size = max(1, math.ceil(q.shape[1] / steps))
if q.device.type == "mps":
max_slice_for_mps = max(1, ((2 ** 31) - 1) // max(1, q.shape[0] * k.shape[2]))
slice_size = min(slice_size, max_slice_for_mps)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = torch.bmm(q[:, i:end], k) * scale
@@ -270,6 +273,15 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except RuntimeError as e:
if q.device.type == "mps" and ("INT_MAX" in str(e) or "MPSGraph" in str(e) or "MPSGaph" in str(e)):
model_management.soft_empty_cache(True)
steps *= 2
if steps > 4096:
raise e
logging.warning("MPS attention limit reached, increasing steps and trying again {}".format(steps))
continue
raise e
except model_management.OOM_EXCEPTION as e:
model_management.soft_empty_cache(True)
steps *= 2

View File

@@ -1105,7 +1105,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
# handle entire input fitting in a single tile
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
output[b:b+1] = function(s).to(output_device)
output[b:b+1] = function(s.contiguous()).to(output_device)
if pbar is not None:
pbar.update(1)
continue
@@ -1125,7 +1125,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device)
ps = function(s_in.contiguous()).to(output_device)
mask = torch.ones_like(ps)
for d in range(2, dims + 2):