Compare commits

...

2 Commits

Author SHA1 Message Date
Octopus
fb46266f4c
Merge 7c1ec4a59e into fce0398470 2026-04-29 18:57:25 +08:00
octo-patch
7c1ec4a59e Fix weight_norm parametrizations causing incorrect partial load skip in _load_list
When a module uses torch.nn.utils.parametrizations.weight_norm, its weight
parameter is moved into a 'parametrizations' child sub-module as original0/original1.
named_parameters(recurse=True) yields 'parametrizations.weight.original0' while
named_parameters(recurse=False) does not, causing _load_list() to incorrectly
classify the module as having random weights in non-leaf sub-modules and skipping
it during partial VRAM loading.

The fix skips 'parametrizations.*' entries in the check, so weight-normed modules
are correctly included in the load list and their parameters moved to GPU as needed.

The AudioOobleckVAE (Stable Audio) had a disable_offload=True workaround that
forced full GPU load to avoid this bug. With the root cause fixed, that workaround
is no longer necessary and is removed, allowing partial VRAM offloading on systems
with limited GPU memory.

Fixes #11855
2026-04-17 12:22:17 +08:00
2 changed files with 1 additions and 2 deletions

View File

@ -738,7 +738,7 @@ class ModelPatcher:
default = False
params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
if name not in params and not name.startswith("parametrizations."):
default = True # default random weights in non leaf modules
break
if default and default_device is not None:

View File

@ -597,7 +597,6 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.disable_offload = True
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})