ops: defer creation of the parameters until state dict load

If running on Windows, defer creation of the layer parameters until the state
dict is loaded. This avoids a massive charge in windows commit charge spike
when a model is created and not loaded.

This problem doesnt exist on Linux as linux allows RAM overcommit,
however windows does not. Before dynamic memory work this was also a non issue
as every non-quant model would just immediate RAM load and need the memory
anyway.

Make the workaround windows specific, as there may be someone out there with
some training from scratch workflow (which this might break), and assume said
someone is on Linux.
This commit is contained in:
Rattus 2026-01-18 19:29:58 +10:00
parent 1ce94c7260
commit c1a9b4d565

View File

@ -19,7 +19,7 @@
import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
import comfy.rmsnorm
import json
@ -280,6 +280,54 @@ class CastWeightBiasOp:
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
return super().__init__(in_features, out_features, bias, device, dtype)
# Issue is with `torch.empty` still reserving the full memory for the layer.
# Windows doesn't over-commit memory so without this, We are momentarily commit
# charged for the weight even though we might zero-copy it when we load the
# state dict. If the commit charge exceeds the ceiling we can destabilize the
# system.
torch.nn.Module.__init__(self)
self.in_features = in_features
self.out_features = out_features
self.weight = None
self.bias = None
self.comfy_need_lazy_init_bias=bias
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k,v in state_dict.items():
if k[prefix_len:] == "weight":
if not assign_to_params_buffers:
v = v.clone()
self.weight = torch.nn.Parameter(v, requires_grad=False)
elif k[prefix_len:] == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
def reset_parameters(self):
return None