ModelPatcherDynamic: force load non leaf weights (#12433)

The current behaviour of the default ModelPatcher is to .to a model
only if its fully loaded, which is how random non-leaf weights get
loaded in non-LowVRAM conditions.

The however means they never get loaded in dynamic_vram. In the
dynamic_vram case, force load them to the GPU.
This commit is contained in:
rattus 2026-02-12 16:51:50 -08:00 committed by GitHub
parent 4a93a62371
commit 117e214354
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -679,18 +679,19 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False):
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
skip = True # skip random weights in non leaf modules
default = True # default random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher):
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
for x in loading:
@ -1579,7 +1580,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True)
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)