Merge remote-tracking branch 'upstream/master' into qwen35

kijai 2026-03-05 00:23:17 +02:00
commit 371a714747
5 changed files with 32 additions and 21 deletions

--- a/comfy/model_management.py
+++ b/comfy/model_management.py

@@ -830,11 +830,14 @@ def unet_offload_device():
     return torch.device("cpu")
 
 def unet_inital_load_device(parameters, dtype):
+    cpu_dev = torch.device("cpu")
+    if comfy.memory_management.aimdo_enabled:
+        return cpu_dev
+
     torch_dev = get_torch_device()
     if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
         return torch_dev
 
-    cpu_dev = torch.device("cpu")
     if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
         return cpu_dev
@@ -842,7 +845,7 @@ def unet_inital_load_device(parameters, dtype):
     mem_dev = get_free_memory(torch_dev)
     mem_cpu = get_free_memory(cpu_dev)
-    if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
+    if mem_dev > mem_cpu and model_size < mem_dev:
         return torch_dev
     else:
         return cpu_dev
@@ -945,6 +948,9 @@ def text_encoder_device():
     return torch.device("cpu")
 
 def text_encoder_initial_device(load_device, offload_device, model_size=0):
+    if comfy.memory_management.aimdo_enabled:
+        return offload_device
+
     if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
         return offload_device
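
Taken together, the hunks above make Aimdo the first consideration: with the allocator enabled, both the UNet and the text encoder now start on the CPU/offload device unconditionally, and the free-memory comparison reverts to the upstream form without the aimdo_enabled clause. A minimal sketch of the resulting precedence, using plain strings in place of the real VRAMState enum and module globals:

    import torch

    def initial_device(aimdo_enabled, vram_state, disable_smart_memory):
        # Illustrative only: mirrors the branch order of unet_inital_load_device.
        cpu_dev = torch.device("cpu")
        if aimdo_enabled:                    # Aimdo streams weights in at run time
            return cpu_dev
        torch_dev = torch.device("cuda")     # stand-in for get_torch_device()
        if vram_state in ("HIGH_VRAM", "SHARED"):
            return torch_dev
        if disable_smart_memory or vram_state == "NO_VRAM":
            return cpu_dev
        return torch_dev                     # real code compares free memory here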

--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py

@@ -307,7 +307,13 @@ class ModelPatcher:
         return self.model.lowvram_patch_counter
 
     def get_free_memory(self, device):
-        return comfy.model_management.get_free_memory(device)
+        # Prioritize batching (incl. CFG/conds etc.) over keeping the model resident.
+        # In the vast majority of setups a little offloading on the giant model more
+        # than pays for CFG, so return everything both torch and Aimdo could give us.
+        aimdo_mem = 0
+        if comfy.memory_management.aimdo_enabled:
+            aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
+        return comfy.model_management.get_free_memory(device) + aimdo_mem
 
     def get_clone_model_override(self):
         return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
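
The rewritten get_free_memory deliberately over-reports: VRAM held by Aimdo-managed weights counts as free because those weights can be vacated on demand, so batching (including CFG cond/uncond pairs) wins over model residency. A hedged sketch of how a caller might size a batch against the combined figure; max_batch and the reserve value are illustrative, not from this diff:

    # free_memory: the combined torch + Aimdo figure returned above.
    # memory_per_item: estimated VRAM cost of one batch element.
    def max_batch(free_memory, memory_per_item, reserve=1024**3):
        usable = max(free_memory - reserve, 0)  # keep headroom for activations
        return max(usable // memory_per_item, 1)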
@@ -699,7 +705,7 @@ class ModelPatcher:
         for key in list(self.pinned):
             self.unpin_weight(key)
 
-    def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
+    def _load_list(self, for_dynamic=False, default_device=None):
         loading = []
         for n, m in self.model.named_modules():
             default = False
@@ -727,8 +733,13 @@ class ModelPatcher:
                     return 0
                 module_offload_mem += check_module_offload_mem("{}.weight".format(n))
                 module_offload_mem += check_module_offload_mem("{}.bias".format(n))
-                prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
-                loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
+                # Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
+                # Non-dynamic: prioritize by module offload cost.
+                if for_dynamic:
+                    sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
+                else:
+                    sort_criteria = (module_offload_mem,)
+                loading.append(sort_criteria + (module_mem, n, m, params))
         return loading
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
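
The appended tuples double as sort keys and are compared lexicographically, so a plain ascending sort yields every sub-64KB module first (False sorts before True), with each group ordered largest-first via the negated size. A quick worked example with made-up sizes:

    # Made-up module sizes (bytes) demonstrating the for_dynamic ordering.
    sizes = [80 * 1024, 16 * 1024, 4 * 1024 * 1024, 32 * 1024]
    entries = sorted(((s >= 64 * 1024, -s), f"{s // 1024}KB") for s in sizes)
    print([name for _, name in entries])
    # ['32KB', '16KB', '4096KB', '80KB'] -- small modules first, then big-to-small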
@@ -1460,12 +1471,6 @@ class ModelPatcherDynamic(ModelPatcher):
         vbar = self._vbar_get()
         return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
 
-    def get_free_memory(self, device):
-        #NOTE: on high condition / batch counts, estimate should have already vacated
-        #all non-dynamic models so this is safe even if its not 100% true that this
-        #would all be avaiable for inference use.
-        return comfy.model_management.get_total_memory(device) - self.model_size()
-
     #Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
     def pin_weight_to_device(self, key):
@@ -1508,11 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher):
         if vbar is not None:
             vbar.prioritize()
 
-        loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
-        loading.sort(reverse=True)
+        loading = self._load_list(for_dynamic=True, default_device=device_to)
+        loading.sort()
         for x in loading:
-            _, _, _, n, m, params = x
+            *_, module_mem, n, m, params = x
 
             def set_dirty(item, dirty):
                 if dirty or not hasattr(item, "_v_signature"):
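
Because the sort-key prefix is now one element long for regular loads and two for dynamic loads, the consumers switch to starred unpacking, which discards however many leading keys there are while binding the fixed trailing fields; the explicit sort(reverse=True) also goes away since the keys already encode priority in ascending order. A tiny illustration of the unpacking, with placeholder values:

    # Dynamic entries: (is_large, -offload_mem, module_mem, name, module, params)
    entry = (False, -16384, 16384, "block.0.attn", object(), [])
    *_, module_mem, n, m, params = entry
    assert (module_mem, n) == (16384, "block.0.attn")
    # Non-dynamic entries carry a single leading key; *_ absorbs that too.
    *_, module_mem, n, m, params = (16384, 16384, "block.0.attn", object(), [])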
@@ -1627,9 +1632,9 @@ class ModelPatcherDynamic(ModelPatcher):
         return freed
 
     def partially_unload_ram(self, ram_to_unload):
-        loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
+        loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
         for x in loading:
-            _, _, _, _, m, _ = x
+            *_, m, _ = x
             ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
             if ram_to_unload <= 0:
                 return

--- a/comfy/ops.py
+++ b/comfy/ops.py

@@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
         return
     os, weight_a, bias_a = offload_stream
     device=None
-    #FIXME: This is not good RTTI
-    if not isinstance(weight_a, torch.Tensor):
+    #FIXME: This is really bad RTTI
+    if weight_a is not None and not isinstance(weight_a, torch.Tensor):
         comfy_aimdo.model_vbar.vbar_unpin(s._v)
         device = weight_a
     if os is None:

--- a/comfy/sd.py
+++ b/comfy/sd.py

@@ -429,7 +429,7 @@ class CLIP:
     def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
         self.cond_stage_model.reset_clip_options()
-        self.load_model()
+        self.load_model(tokens)
         self.cond_stage_model.set_clip_options({"layer": None})
         self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
         return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

--- a/requirements.txt
+++ b/requirements.txt

@@ -22,7 +22,7 @@ alembic
 SQLAlchemy
 av>=14.2.0
 comfy-kitchen>=0.2.7
-comfy-aimdo>=0.2.4
+comfy-aimdo>=0.2.6
 requests
 
 #non essential dependencies: