diff --git a/comfy/ops.py b/comfy/ops.py index 1057cda6b..ee8b32f18 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,7 +19,7 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import comfy.float import comfy.rmsnorm import json @@ -280,6 +280,54 @@ class CastWeightBiasOp: class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super().__init__(in_features, out_features, bias, device, dtype) + + # Issue is with `torch.empty` still reserving the full memory for the layer. + # Windows doesn't over-commit memory so without this, We are momentarily commit + # charged for the weight even though we might zero-copy it when we load the + # state dict. If the commit charge exceeds the ceiling we can destabilize the + # system. + torch.nn.Module.__init__(self) + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.bias = None + self.comfy_need_lazy_init_bias=bias + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k,v in state_dict.items(): + if k[prefix_len:] == "weight": + if not assign_to_params_buffers: + v = v.clone() + self.weight = torch.nn.Parameter(v, requires_grad=False) + elif k[prefix_len:] == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + self.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + #Reconcile default construction of the weight if its missing. + if self.weight is None: + v = torch.zeros(self.in_features, self.out_features) + self.weight = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"weight") + if self.bias is None and self.comfy_need_lazy_init_bias: + v = torch.zeros(self.out_features,) + self.bias = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"bias") + + def reset_parameters(self): return None