Fix weight_norm parametrizations causing incorrect partial load skip in _load_list

When a module uses torch.nn.utils.parametrizations.weight_norm, its weight
parameter is moved into a 'parametrizations' child sub-module as original0/original1.
named_parameters(recurse=True) yields 'parametrizations.weight.original0' while
named_parameters(recurse=False) does not, causing _load_list() to incorrectly
classify the module as having random weights in non-leaf sub-modules and skipping
it during partial VRAM loading.

The fix skips 'parametrizations.*' entries in the check, so weight-normed modules
are correctly included in the load list and their parameters moved to GPU as needed.

The AudioOobleckVAE (Stable Audio) had a disable_offload=True workaround that
forced full GPU load to avoid this bug. With the root cause fixed, that workaround
is no longer necessary and is removed, allowing partial VRAM offloading on systems
with limited GPU memory.

Fixes #11855
This commit is contained in:
octo-patch 2026-04-17 12:22:17 +08:00
parent d0c53c50c2
commit 7c1ec4a59e
2 changed files with 1 additions and 2 deletions

View File

@ -732,7 +732,7 @@ class ModelPatcher:
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
if name not in params and not name.startswith("parametrizations."):
default = True # default random weights in non leaf modules
break
if default and default_device is not None:

View File

@ -596,7 +596,6 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.disable_offload = True
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})