From 97f198e4215680a83749ba95849f3cdcfa7aa64a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:07:35 -0800 Subject: [PATCH 01/10] Fix qwen controlnet regression. (#10657) --- comfy/ldm/qwen_image/controlnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py index 92ac3cf0a..a6d408104 100644 --- a/comfy/ldm/qwen_image/controlnet.py +++ b/comfy/ldm/qwen_image/controlnet.py @@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) From 1d69245981f9fb3861018613246042296d887dd3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:08:13 -0800 Subject: [PATCH 02/10] Enable pinned memory by default on Nvidia. (#10656) Removed the --fast pinned_memory flag. You can use --disable-pinned-memory to disable it. Please report if it causes any issues. --- comfy/cli_args.py | 3 ++- comfy/model_management.py | 22 +++++++++------------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3947e62a8..2f30b72d2 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -145,10 +145,11 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" - PinnedMem = "pinned_memory" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") + parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") diff --git a/comfy/model_management.py b/comfy/model_management.py index 0d040e55e..4d13c52c1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1085,22 +1085,21 @@ def cast_to_device(tensor, device, dtype, copy=False): PINNED_MEMORY = {} TOTAL_PINNED_MEMORY = 0 -if PerformanceFeature.PinnedMem in args.fast: - if WINDOWS: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% - else: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 -else: - MAX_PINNED_MEMORY = -1 +MAX_PINNED_MEMORY = -1 +if not args.disable_pinned_memory: + if is_nvidia(): + if WINDOWS: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% + else: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 + logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) + def pin_memory(tensor): global TOTAL_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: return False - if not is_nvidia(): - return False - if not is_device_cpu(tensor.device): return False @@ -1121,9 +1120,6 @@ def unpin_memory(tensor): if MAX_PINNED_MEMORY <= 0: return False - if not is_nvidia(): - return False - if not is_device_cpu(tensor.device): return False From 09dc24c8a982776abd5cb2f71e3d041139e1d5b2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:11:15 -0800 Subject: [PATCH 03/10] Pinned mem also seems to work on AMD. (#10658) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 4d13c52c1..7a30c4bec 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1087,7 +1087,7 @@ PINNED_MEMORY = {} TOTAL_PINNED_MEMORY = 0 MAX_PINNED_MEMORY = -1 if not args.disable_pinned_memory: - if is_nvidia(): + if is_nvidia() or is_amd(): if WINDOWS: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% else: From e05c90712670fa4a2ffebd44046fc78747193a36 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:11:30 -0800 Subject: [PATCH 04/10] Clarify release cycle. (#10667) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4204777e9..8142f595b 100644 --- a/README.md +++ b/README.md @@ -112,10 +112,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ## Release Process -ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: +ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** - - Releases a new stable version (e.g., v0.7.0) + - Releases a new stable version (e.g., v0.7.0) roughly every week. + - Commits outside of the stable release tags may be very unstable and break many custom nodes. - Serves as the foundation for the desktop release 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)** From eb1c42f6498ce44aef4dbed3bb665ac98a28254d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:24:28 -0800 Subject: [PATCH 05/10] Tell users they need to upload their logs in bug reports. (#10671) --- .github/ISSUE_TEMPLATE/bug-report.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 3cf2717b7..6556677e0 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,13 +8,15 @@ body: Before submitting a **Bug Report**, please ensure the following: - **1:** You are running the latest version of ComfyUI. - - **2:** You have looked at the existing bug reports and made sure this isn't already reported. + - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report. - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing - `--disable-all-custom-nodes` command line argument. + `--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version. - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. - If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + ## Very Important + + Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. - type: checkboxes id: custom-nodes-test attributes: From cf97b033ee80cf245b4592d42f89e6de67e409a4 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 7 Nov 2025 12:20:48 +1000 Subject: [PATCH 06/10] mm: guard against double pin and unpin explicitly (#10672) As commented, if you let cuda be the one to detect double pin/unpinning it actually creates an asyc GPU error. --- comfy/model_management.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7a30c4bec..a13b24cea 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1103,6 +1103,12 @@ def pin_memory(tensor): if not is_device_cpu(tensor.device): return False + if tensor.is_pinned(): + #NOTE: Cuda does detect when a tensor is already pinned and would + #error below, but there are proven cases where this also queues an error + #on the GPU async. So dont trust the CUDA API and guard here + return False + size = tensor.numel() * tensor.element_size() if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1123,6 +1129,12 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False + if not tensor.is_pinned(): + #NOTE: Cuda does detect when a tensor is already pinned and would + #error below, but there are proven cases where this also queues an error + #on the GPU async. So dont trust the CUDA API and guard here + return False + ptr = tensor.data_ptr() if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) From a1a70362ca376cff05a0514e0ce771ab26d92fd9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 7 Nov 2025 08:15:05 -0800 Subject: [PATCH 07/10] Only unpin tensor if it was pinned by ComfyUI (#10677) --- comfy/model_management.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a13b24cea..7012df858 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1129,13 +1129,18 @@ def unpin_memory(tensor): if not is_device_cpu(tensor.device): return False - if not tensor.is_pinned(): - #NOTE: Cuda does detect when a tensor is already pinned and would - #error below, but there are proven cases where this also queues an error - #on the GPU async. So dont trust the CUDA API and guard here + ptr = tensor.data_ptr() + size = tensor.numel() * tensor.element_size() + + size_stored = PINNED_MEMORY.get(ptr, None) + if size_stored is None: + logging.warning("Tried to unpin tensor not pinned by ComfyUI") + return False + + if size != size_stored: + logging.warning("Size of pinned tensor changed") return False - ptr = tensor.data_ptr() if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) if len(PINNED_MEMORY) == 0: From 2abd2b5c2049a9625b342bcb7decedd5d1645f66 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 8 Nov 2025 12:52:02 -0800 Subject: [PATCH 08/10] Make ScaleROPE node work on Flux. (#10686) --- comfy/ldm/flux/model.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 14f90cea5..b9d36f202 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -210,7 +210,7 @@ class Flux(nn.Module): img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0): + def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): bs, c, h, w = x.shape patch_size = self.patch_size x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) @@ -222,10 +222,22 @@ class Flux(nn.Module): h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs): @@ -241,7 +253,7 @@ class Flux(nn.Module): h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = self.process_img(x) + img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] if ref_latents is not None: h = 0 From e632e5de281b91dd7199636dd6d82126fbfb07d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 9 Nov 2025 15:06:39 -0800 Subject: [PATCH 09/10] Add logging for model unloading. (#10692) --- comfy/model_patcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5a31a8734..17e06a869 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -909,6 +909,7 @@ class ModelPatcher: self.model.model_lowvram = True self.model.lowvram_patch_counter += patch_counter self.model.model_loaded_weight_memory -= memory_freed + logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter)) return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): From dea899f22125d38a8b48147d6cce89a2b659fdeb Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 9 Nov 2025 15:51:33 -0800 Subject: [PATCH 10/10] Unload weights if vram usage goes up between runs. (#10690) --- comfy/model_management.py | 11 +++++++++-- comfy/model_patcher.py | 20 +++++++++++++------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 7012df858..a4410f2ec 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -503,7 +503,11 @@ class LoadedModel: use_more_vram = lowvram_model_memory if use_more_vram == 0: use_more_vram = 1e32 - self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + if use_more_vram > 0: + self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + else: + self.model.partially_unload(self.model.offload_device, -use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: @@ -689,7 +693,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu current_free_mem = get_free_memory(torch_dev) + loaded_memory lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) - lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) + lowvram_model_memory = lowvram_model_memory - loaded_memory + + if lowvram_model_memory == 0: + lowvram_model_memory = 0.1 if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 0.1 diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 17e06a869..68b0a9192 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -843,7 +843,7 @@ class ModelPatcher: self.object_patches_backup.clear() - def partially_unload(self, device_to, memory_to_free=0): + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): with self.use_ejected(): hooks_unpatched = False memory_freed = 0 @@ -887,13 +887,19 @@ class ModelPatcher: module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, weight_key) - m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) + patch_counter += 1 if bias_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, bias_key) - m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) + patch_counter += 1 cast_weight = True if cast_weight: