Mirror of https://github.com/comfyanonymous/ComfyUI.git
ops: defer creation of the parameters until state dict load
If running on Windows, defer creation of the layer parameters until the state dict is loaded. This avoids a massive spike in the Windows commit charge when a model is created but not loaded. The problem doesn't exist on Linux, which allows RAM overcommit, whereas Windows does not. Before the dynamic memory work this was also a non-issue, since every non-quant model would immediately be loaded into RAM and need the memory anyway. Make the workaround Windows-specific, as there may be someone out there with a train-from-scratch workflow (which this might break), and assume said someone is on Linux.
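In essence the change makes the layer an empty shell at construction time and only materializes real parameters when tensors arrive from the state dict. A minimal, standalone sketch of that pattern (not the ComfyUI implementation; LazyLinear and attach_from_state_dict are hypothetical names used only for illustration):

import torch

class LazyLinear(torch.nn.Module):
    # Sketch of the deferred-parameter idea: nothing is allocated in __init__,
    # so no RAM is committed for weights that will be replaced at load time.
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = None          # no torch.empty() here
        self.bias = None
        self._want_bias = bias

    def attach_from_state_dict(self, state_dict, prefix=""):
        # Wrap the loaded tensors directly instead of copying them into
        # pre-allocated storage, so the memory is only committed once.
        self.weight = torch.nn.Parameter(state_dict[prefix + "weight"], requires_grad=False)
        if self._want_bias:
            self.bias = torch.nn.Parameter(state_dict[prefix + "bias"], requires_grad=False)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias)

layer = LazyLinear(4096, 4096)   # no 64 MiB float32 weight committed yet
sd = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
layer.attach_from_state_dict(sd)
print(layer(torch.randn(1, 4096)).shape)   # torch.Size([1, 4096])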
parent 1ce94c7260
commit c1a9b4d565
comfy/ops.py | 50
@@ -19,7 +19,7 @@
 import torch
 import logging
 import comfy.model_management
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
 import comfy.float
 import comfy.rmsnorm
 import json
@@ -280,6 +280,54 @@ class CastWeightBiasOp:
 
 
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
+        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+            if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
+                return super().__init__(in_features, out_features, bias, device, dtype)
+
+            # The issue is that `torch.empty` still reserves the full memory for the layer.
+            # Windows doesn't over-commit memory, so without this we are momentarily
+            # commit-charged for the weight even though we might zero-copy it when we load
+            # the state dict. If the commit charge exceeds the ceiling we can destabilize
+            # the system.
+            torch.nn.Module.__init__(self)
+            self.in_features = in_features
+            self.out_features = out_features
+            self.weight = None
+            self.bias = None
+            self.comfy_need_lazy_init_bias = bias
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+            if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
+                return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                                     missing_keys, unexpected_keys, error_msgs)
+
+            assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
+            prefix_len = len(prefix)
+            for k, v in state_dict.items():
+                if k[prefix_len:] == "weight":
+                    if not assign_to_params_buffers:
+                        v = v.clone()
+                    self.weight = torch.nn.Parameter(v, requires_grad=False)
+                elif k[prefix_len:] == "bias" and v is not None:
+                    if not assign_to_params_buffers:
+                        v = v.clone()
+                    self.bias = torch.nn.Parameter(v, requires_grad=False)
+                else:
+                    unexpected_keys.append(k)
+
+            # Reconcile default construction of the weight if it is missing.
+            # Note: nn.Linear stores its weight as (out_features, in_features).
+            if self.weight is None:
+                v = torch.zeros(self.out_features, self.in_features)
+                self.weight = torch.nn.Parameter(v, requires_grad=False)
+                missing_keys.append(prefix + "weight")
+            if self.bias is None and self.comfy_need_lazy_init_bias:
+                v = torch.zeros(self.out_features)
+                self.bias = torch.nn.Parameter(v, requires_grad=False)
+                missing_keys.append(prefix + "bias")
+
         def reset_parameters(self):
             return None
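For context on the assign_to_params_buffers check above: in recent PyTorch versions (2.1+), loading with load_state_dict(..., assign=True) records that flag in each submodule's local metadata, which lets _load_from_state_dict adopt the incoming tensors instead of cloning them. A small hedged sketch with plain torch.nn.Linear modules (not the ComfyUI ops) showing the zero-copy effect:

import torch

src = torch.nn.Linear(8, 8)
dst = torch.nn.Linear(8, 8)

sd = src.state_dict()
# assign=True swaps the state-dict tensors in rather than copying them into the
# destination's pre-allocated parameters, so no extra RAM is committed.
dst.load_state_dict(sd, assign=True)

# The destination parameter now shares storage with the state-dict tensor.
print(dst.weight.data_ptr() == sd["weight"].data_ptr())   # True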