mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-19 06:19:38 +08:00
ops: bypass init() of weight for embedding layer
This similarly consumes a large commit charge, especially for TEs. It can cause a permanently leaked commit charge, which can destabilize systems that are close to the commit ceiling and generally confuses the RAM stats.
This commit is contained in:
parent
2fe5238d6e
commit
90fefbd89c
97
comfy/ops.py
97
comfy/ops.py
@ -306,6 +306,33 @@ class CastWeightBiasOp:
|
|||||||
bias_function = []
|
bias_function = []
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
|
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
                               missing_keys, unexpected_keys, weight_shape,
                               bias_shape=None):
    """Populate ``module.weight`` / ``module.bias`` directly from ``state_dict``.

    Intended for modules whose constructor left ``weight`` (and ``bias``) as
    ``None`` instead of allocating them, so the tensors found in the state
    dict become the parameters.

    Args:
        module: Module to populate; must expose ``weight`` and ``bias``
            attributes (either may be ``None`` on entry).
        state_dict: Mapping of fully qualified key -> tensor.
        prefix: Qualified-name prefix of ``module`` (e.g. ``"layer."``).
        local_metadata: Load metadata; only ``"assign_to_params_buffers"``
            is read. When it is falsy the incoming tensors are cloned.
        missing_keys: List extended in place with keys that had to be
            zero-filled because the state dict did not provide them.
        unexpected_keys: List extended in place with keys under ``prefix``
            that this module does not consume.
        weight_shape: Shape of the zero weight created when none was loaded.
        bias_shape: Shape of the zero bias created when none was loaded, or
            ``None`` if the module has no bias (a ``"bias"`` key is then
            reported as unexpected).
    """
    assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
    prefix_len = len(prefix)
    for k, v in state_dict.items():
        # Only keys that really live under this module's prefix are ours.
        # Without this guard a foreign key whose prefix merely has the same
        # length (e.g. "other.weight" vs "layer.") would be sliced down to
        # "weight" and overwrite this module's parameter, and every other
        # foreign key would be wrongly reported as unexpected.
        if not k.startswith(prefix):
            continue
        key = k[prefix_len:]
        if key == "weight":
            if not assign_to_params_buffers:
                v = v.clone()
            module.weight = torch.nn.Parameter(v, requires_grad=False)
        elif bias_shape is not None and key == "bias" and v is not None:
            if not assign_to_params_buffers:
                v = v.clone()
            module.bias = torch.nn.Parameter(v, requires_grad=False)
        else:
            unexpected_keys.append(k)

    # Reconcile the default construction of the weight if it is missing.
    if module.weight is None:
        module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
        missing_keys.append(prefix + "weight")

    # A bias is only synthesized for modules that flagged that they need one.
    if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
        module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
        missing_keys.append(prefix + "bias")
|
||||||
|
|
||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||||
@ -333,29 +360,16 @@ class disable_weight_init:
|
|||||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
disable_weight_init._lazy_load_from_state_dict(
|
||||||
prefix_len = len(prefix)
|
self,
|
||||||
for k,v in state_dict.items():
|
state_dict,
|
||||||
if k[prefix_len:] == "weight":
|
prefix,
|
||||||
if not assign_to_params_buffers:
|
local_metadata,
|
||||||
v = v.clone()
|
missing_keys,
|
||||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
unexpected_keys,
|
||||||
elif k[prefix_len:] == "bias" and v is not None:
|
weight_shape=(self.in_features, self.out_features),
|
||||||
if not assign_to_params_buffers:
|
bias_shape=(self.out_features,),
|
||||||
v = v.clone()
|
)
|
||||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
else:
|
|
||||||
unexpected_keys.append(k)
|
|
||||||
|
|
||||||
#Reconcile default construction of the weight if its missing.
|
|
||||||
if self.weight is None:
|
|
||||||
v = torch.zeros(self.in_features, self.out_features)
|
|
||||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
missing_keys.append(prefix+"weight")
|
|
||||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
|
||||||
v = torch.zeros(self.out_features,)
|
|
||||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
missing_keys.append(prefix+"bias")
|
|
||||||
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
@ -547,6 +561,43 @@ class disable_weight_init:
|
|||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||||
|
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
             norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
             _freeze=False, device=None, dtype=None):
    """Build the embedding, skipping weight allocation on the lazy path.

    When ``comfy.model_management.WINDOWS`` and
    ``comfy.memory_management.aimdo_enabled`` are both truthy, the weight is
    left as ``None`` so it can be filled in when the state dict is loaded.
    In every other case the stock ``torch.nn.Embedding`` constructor runs.

    Parameters mirror ``torch.nn.Embedding.__init__``.
    """
    lazy_init = comfy.model_management.WINDOWS and comfy.memory_management.aimdo_enabled
    # A caller-supplied ``_weight`` is wrapped by the stock constructor rather
    # than freshly allocated, so deferring to it keeps the tensor (and honours
    # ``_freeze``) instead of silently discarding it on the lazy path.
    if not lazy_init or _weight is not None:
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight,
                         _freeze, device, dtype)
        return

    # Lazy path: initialise only the Module machinery and record the
    # configuration; no weight tensor is created here.
    torch.nn.Module.__init__(self)
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.sparse = sparse
    self.weight = None
    self.bias = None
    # Requested dtype is remembered since there is no weight to carry it yet.
    self.weight_comfy_model_dtype = dtype
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                          strict, missing_keys, unexpected_keys, error_msgs):
    """Load this layer's weight, using the lazy loader when it applies.

    The lazy loader is used only when both ``comfy.model_management.WINDOWS``
    and ``comfy.memory_management.aimdo_enabled`` are truthy; otherwise the
    inherited ``_load_from_state_dict`` handles the load unchanged.
    """
    use_lazy_loader = comfy.model_management.WINDOWS and comfy.memory_management.aimdo_enabled
    if use_lazy_loader:
        # No bias_shape is passed: the embedding has a weight only.
        embedding_shape = (self.num_embeddings, self.embedding_dim)
        disable_weight_init._lazy_load_from_state_dict(
            self, state_dict, prefix, local_metadata,
            missing_keys, unexpected_keys,
            weight_shape=embedding_shape,
        )
        return None

    return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                         missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
self.bias = None
|
self.bias = None
|
||||||
return None
|
return None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user