Use is_oom_exception for all exception checks

This commit is contained in:
Alex Butler 2025-06-21 14:20:41 +01:00
parent e085cc478c
commit a19cb1a13b
5 changed files with 16 additions and 6 deletions

View File

@ -321,7 +321,9 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
if not model_management.is_oom_exception(e):
raise
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:

View File

@ -232,7 +232,9 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
if not model_management.is_oom_exception(e):
raise
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
@ -287,7 +289,9 @@ def pytorch_attention(q, k, v):
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
except Exception as ex:
if not model_management.is_oom_exception(ex):
raise
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out

View File

@ -169,7 +169,9 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
except Exception as ex:
if not model_management.is_oom_exception(ex):
raise
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)

View File

@ -68,7 +68,9 @@ class ImageUpscaleWithModel:
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
except Exception as e:
if not model_management.is_oom_exception(e):
raise
tile //= 2
if tile < 128:
raise e

View File

@ -431,7 +431,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
logging.error(traceback.format_exc())
tips = ""
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
if comfy.model_management.is_oom_exception(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()