diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5a22ef030..e81bb37d3 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -259,7 +259,10 @@ def slice_attention(q, k, v): while True: try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + slice_size = max(1, math.ceil(q.shape[1] / steps)) + if q.device.type == "mps": + max_slice_for_mps = max(1, ((2 ** 31) - 1) // max(1, q.shape[0] * k.shape[2])) + slice_size = min(slice_size, max_slice_for_mps) for i in range(0, q.shape[1], slice_size): end = i + slice_size s1 = torch.bmm(q[:, i:end], k) * scale @@ -270,6 +273,15 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break + except RuntimeError as e: + if q.device.type == "mps" and ("INT_MAX" in str(e) or "MPSGraph" in str(e) or "MPSGaph" in str(e)): + model_management.soft_empty_cache(True) + steps *= 2 + if steps > 4096: + raise e + logging.warning("MPS attention limit reached, increasing steps and trying again {}".format(steps)) + continue + raise e except model_management.OOM_EXCEPTION as e: model_management.soft_empty_cache(True) steps *= 2 diff --git a/comfy/utils.py b/comfy/utils.py index 1337e2205..ecc2aa20f 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1105,7 +1105,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am # handle entire input fitting in a single tile if all(s.shape[d+2] <= tile[d] for d in range(dims)): - output[b:b+1] = function(s).to(output_device) + output[b:b+1] = function(s.contiguous()).to(output_device) if pbar is not None: pbar.update(1) continue @@ -1125,7 +1125,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am s_in = s_in.narrow(d + 2, pos, l) upscaled.append(round(get_pos(d, pos))) - ps = function(s_in).to(output_device) + ps = function(s_in.contiguous()).to(output_device) mask = torch.ones_like(ps) for d in range(2, dims + 2):