From 06f85e2c792c626f2cab3cb4f94cd30d43e9347b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:08:51 +0200 Subject: [PATCH 1/6] Fix text encoder lora loading for wrapped models (#12852) --- comfy/lora.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index f36ddb046..63ee85323 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}): for k in sdk: if k.endswith(".weight"): key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models + if tp > 0 and not k.startswith("clip_"): + key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False From 814dab9f4636df22a36cbbad21e35ac7609a0ef2 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 10 Mar 2026 10:03:22 +0800 Subject: [PATCH 2/6] Update workflow templates to v0.9.18 (#12857) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b1db1cf24..bb58f8d01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.11 +comfyui-workflow-templates==0.9.18 comfyui-embedded-docs==0.4.3 torch torchsde From 740d998c9cc821ca0a72b5b5d4b17aba1aec6b44 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:49:31 +0900 Subject: [PATCH 3/6] fix(manager): improve install guidance when comfyui-manager is not installed (#12810) --- main.py | 13 ++++++++++--- manager_requirements.txt | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index 1977f9362..83a7244db 100644 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ comfy.options.enable_args_parsing() import os import importlib.util +import shutil import importlib.metadata import folder_paths import time @@ -64,8 +65,15 @@ if __name__ == "__main__": def handle_comfyui_manager_unavailable(): - if not args.windows_standalone_build: - logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. 
Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt") + uv_available = shutil.which("uv") is not None + + pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}" + msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}" + if uv_available: + msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}" + msg += "\n" + logging.warning(msg) args.enable_manager = False @@ -173,7 +181,6 @@ execute_prestartup_script() # Main code import asyncio -import shutil import threading import gc diff --git a/manager_requirements.txt b/manager_requirements.txt index c420cc48e..6bcc3fb50 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b1 +comfyui_manager==4.1b2 \ No newline at end of file From c4fb0271cd7fbddb2381372b1f7c1206d1dd58fc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:37:58 -0700 Subject: [PATCH 4/6] Add a way for nodes to add pre attn patches to flux model. (#12861) --- comfy/ldm/flux/layers.py | 15 ++++++++++++++- comfy/ldm/flux/math.py | 2 ++ comfy/ldm/flux/model.py | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 8b3f500d7..e20d498f8 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module): del txt_k, img_k v = torch.cat((txt_v, img_v), dim=2) del txt_v, img_v + + extra_options["img_slice"] = [txt.shape[1], q.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # run actual attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v if "attn1_output_patch" in transformer_patches: - extra_options["img_slice"] = [txt.shape[1], attn.shape[1]] patch = transformer_patches["attn1_output_patch"] for p in patch: attn = p(attn, extra_options) @@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module): del qkv q, k = self.norm(q, k, v) + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 5e764bb46..824daf5e6 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: def _apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]: + freqs_cis = freqs_cis[:, :, :x_.shape[2]] x_out = freqs_cis[..., 0] * x_[..., 0] x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) diff --git 
a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index ef4dcf7c5..00f12c031 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -170,7 +170,7 @@ class Flux(nn.Module):
 
         if "post_input" in patches:
             for p in patches["post_input"]:
-                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
+                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
                 img = out["img"]
                 txt = out["txt"]
                 img_ids = out["img_ids"]

From a912809c252f5a2d69c8ab4035fc262a578fdcee Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Mon, 9 Mar 2026 20:50:10 -0700
Subject: [PATCH 5/6] model_detection: deep clone pre-edited weights (#12862)

Deep clone these weights as needed to avoid segfaulting when the in-place
edit touches the original mmap.
---
 comfy/model_detection.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 6eace4628..35a6822e3 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -1,4 +1,5 @@
 import json
+import comfy.memory_management
 import comfy.supported_models
 import comfy.supported_models_base
 import comfy.utils
@@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
                     new[:old_weight.shape[0]] = old_weight
                     old_weight = new
 
+                if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
+                    old_weight = old_weight.clone()
+
                 w = old_weight.narrow(offset[0], offset[1], offset[2])
             else:
+                if comfy.memory_management.aimdo_enabled:
+                    weight = weight.clone()
                 old_weight = weight
                 w = weight
             w[:] = fun(weight)

From 535c16ce6e3d2634d6eb2fd17ecccb8d497e26a0 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Mon, 9 Mar 2026 21:41:02 -0700
Subject: [PATCH 6/6] Widen OOM_EXCEPTION handling to cover AcceleratorError (#12835)

PyTorch only filters for OOMs raised by its own allocators; however, there
are paths that can OOM in allocators created outside of PyTorch. These
manifest as an AcceleratorError, since PyTorch does not universally
translate such errors to its OOM exception type. Handle it.

A log I have for this case also shows a duplicate asynchronous report of
the error, so call the async discarder to clean up and make these OOMs
look like OOMs.
--- comfy/ldm/modules/attention.py | 3 ++- comfy/ldm/modules/diffusionmodules/model.py | 6 ++++-- comfy/ldm/modules/sub_quadratic_attention.py | 3 ++- comfy/model_management.py | 12 ++++++++++++ comfy/sd.py | 6 ++++-- comfy_extras/nodes_upscale_model.py | 3 ++- execution.py | 2 +- 7 files changed, 27 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 10d051325..b193fe5e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) if first_op_done == False: model_management.soft_empty_cache(True) if cleared_cache == False: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 805592aa5..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -258,7 +258,8 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) model_management.soft_empty_cache(True) steps *= 2 if steps > 128: @@ -314,7 +315,8 @@ def pytorch_attention(q, k, v): try: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") oom_fallback = True if oom_fallback: diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index fab145f1c..f982afc2b 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/model_management.py b/comfy/model_management.py index 07bc8ad67..81550c790 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -270,6 +270,18 @@ try: except: OOM_EXCEPTION = Exception +def is_oom(e): + if isinstance(e, OOM_EXCEPTION): + return True + if isinstance(e, torch.AcceleratorError) and getattr(e, 'error_code', None) == 2: + discard_cuda_async_error() + return True + return False + +def raise_non_oom(e): + if not is_oom(e): + raise e + XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True if args.disable_xformers: diff --git a/comfy/sd.py b/comfy/sd.py index 888ef1e77..adcd67767 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -954,7 +954,8 @@ class VAE: if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples[x:x+batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) 
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -1029,7 +1030,8 @@ class VAE: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples[x:x + batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 97b9e948d..db4f9d231 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode): pbar = comfy.utils.ProgressBar(steps) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) oom = False - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) tile //= 2 if tile < 128: raise e diff --git a/execution.py b/execution.py index 7ccdbf93e..a7791efed 100644 --- a/execution.py +++ b/execution.py @@ -612,7 +612,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, logging.error(traceback.format_exc()) tips = "" - if isinstance(ex, comfy.model_management.OOM_EXCEPTION): + if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.error("Got an OOM, unloading all loaded models.")