From d531e3fb2a885d675d5b6d3a496b4af5d9757af1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:47:44 -0800 Subject: [PATCH 1/5] model_patcher: Improve dynamic offload heuristic (#12759) Define a threshold below which a weight loading takes priority. This actually makes the offload consistent with non-dynamic, because what happens, is when non-dynamic fills ints to_load list, it will fill-up any left-over pieces that could fix large weights with small weights and load them, even though they were lower priority. This actually improves performance because the timy weights dont cost any VRAM and arent worth the control overhead of the DMA etc. --- comfy/model_patcher.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 70f78a089..168ce8430 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,7 +699,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False, default_device=None): + def _load_list(self, for_dynamic=False, default_device=None): loading = [] for n, m in self.model.named_modules(): default = False @@ -727,8 +727,13 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () - loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) + # Dynamic: small weights (<64KB) first, then larger weights prioritized by size. + # Non-dynamic: prioritize by module offload cost. + if for_dynamic: + sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem) + else: + sort_criteria = (module_offload_mem,) + loading.append(sort_criteria + (module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -1508,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) - loading.sort(reverse=True) + loading = self._load_list(for_dynamic=True, default_device=device_to) + loading.sort() for x in loading: - _, _, _, n, m, params = x + *_, module_mem, n, m, params = x def set_dirty(item, dirty): if dirty or not hasattr(item, "_v_signature"): @@ -1627,9 +1632,9 @@ class ModelPatcherDynamic(ModelPatcher): return freed def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) + loading = self._load_list(for_dynamic=True, default_device=self.offload_device) for x in loading: - _, _, _, _, m, _ = x + *_, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) if ram_to_unload <= 0: return From 9b85cf955858b0aca6b7b30c30b404470ea0c964 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 07:49:13 -0800 Subject: [PATCH 2/5] Comfy Aimdo 0.2.5 + Fix offload performance in DynamicVram (#12754) * ops: dont unpin nothing This was calling into aimdo in the none case (offloaded weight). Whats worse, is aimdo syncs for unpinning an offloaded weight, as that is the corner case of a weight getting evicted by its own use which does require a sync. But this was heppening every offloaded weight causing slowdown. * mp: fix get_free_memory policy The ModelPatcherDynamic get_free_memory was deducting the model from to try and estimate the conceptual free memory with doing any offloading. This is kind of what the old memory_memory_required was estimating in ModelPatcher load logic, however in practical reality, between over-estimates and padding, the loader usually underloaded models enough such that sampling could send CFG +/- through together even when partially loaded. So don't regress from the status quo and instead go all in on the idea that offloading is less of an issue than debatching. Tell the sampler it can use everything. --- comfy/model_patcher.py | 14 +++++++------- comfy/ops.py | 4 ++-- requirements.txt | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 168ce8430..7e5ad7aa4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -307,7 +307,13 @@ class ModelPatcher: return self.model.lowvram_patch_counter def get_free_memory(self, device): - return comfy.model_management.get_free_memory(device) + #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In + #the vast majority of setups a little bit of offloading on the giant model more + #than pays for CFG. So return everything both torch and Aimdo could give us + aimdo_mem = 0 + if comfy.memory_management.aimdo_enabled: + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) @@ -1465,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher): vbar = self._vbar_get() return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory - def get_free_memory(self, device): - #NOTE: on high condition / batch counts, estimate should have already vacated - #all non-dynamic models so this is safe even if its not 100% true that this - #would all be avaiable for inference use. - return comfy.model_management.get_total_memory(device) - self.model_size() - #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. def pin_weight_to_device(self, key): diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..8275dd0a5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream): return os, weight_a, bias_a = offload_stream device=None - #FIXME: This is not good RTTI - if not isinstance(weight_a, torch.Tensor): + #FIXME: This is really bad RTTI + if weight_a is not None and not isinstance(weight_a, torch.Tensor): comfy_aimdo.model_vbar.vbar_unpin(s._v) device = weight_a if os is None: diff --git a/requirements.txt b/requirements.txt index 608b0cfa6..110568cd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.4 +comfy-aimdo>=0.2.5 requests #non essential dependencies: From 0a7446ade4bbeecfaf36e9a70eeabbeb0f6e59ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:59:56 +0200 Subject: [PATCH 3/5] Pass tokens when loading text gen model for text generation (#12755) Co-authored-by: Jedrzej Kosinski --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index a9ad7c2d2..8bcd09582 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -428,7 +428,7 @@ class CLIP: def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): self.cond_stage_model.reset_clip_options() - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) From 8811db52db5d0aea49c1dbedd733a6b9304b83a9 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:12:37 -0800 Subject: [PATCH 4/5] comfy-aimdo 0.2.6 (#12764) Comfy Aimdo 0.2.6 fixes a GPU virtual address leak. This would manfiest as an error after a number of workflow runs. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 110568cd3..dae46d873 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.5 +comfy-aimdo>=0.2.6 requests #non essential dependencies: From ac4a943ff364885166def5d418582db971554caf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 4 Mar 2026 13:33:14 -0800 Subject: [PATCH 5/5] Initial load device should be cpu when using dynamic vram. (#12766) --- comfy/model_management.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0f5966371..809600815 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -830,11 +830,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -842,7 +845,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: + if mem_dev > mem_cpu and model_size < mem_dev: return torch_dev else: return cpu_dev @@ -945,6 +948,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device