ops: bypass init() of weight for embedding layer

This similarly consumes large commit charge especially for TEs. It can
cause a permanement leaked commit charge which can destabilize on
systems close to the commit ceiling and generally confuses the RAM
stats.
This commit is contained in:
Rattus 2026-03-10 01:10:36 +10:00
parent 2fe5238d6e
commit 90fefbd89c

View File

@ -306,6 +306,33 @@ class CastWeightBiasOp:
bias_function = [] bias_function = []
class disable_weight_init: class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")
class Linear(torch.nn.Linear, CastWeightBiasOp): class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
@ -333,29 +360,16 @@ class disable_weight_init:
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) disable_weight_init._lazy_load_from_state_dict(
prefix_len = len(prefix) self,
for k,v in state_dict.items(): state_dict,
if k[prefix_len:] == "weight": prefix,
if not assign_to_params_buffers: local_metadata,
v = v.clone() missing_keys,
self.weight = torch.nn.Parameter(v, requires_grad=False) unexpected_keys,
elif k[prefix_len:] == "bias" and v is not None: weight_shape=(self.in_features, self.out_features),
if not assign_to_params_buffers: bias_shape=(self.out_features,),
v = v.clone() )
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
def reset_parameters(self): def reset_parameters(self):
@ -547,6 +561,43 @@ class disable_weight_init:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp): class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
return
torch.nn.Module.__init__(self)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self.weight = None
self.bias = None
self.weight_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.num_embeddings, self.embedding_dim),
)
def reset_parameters(self): def reset_parameters(self):
self.bias = None self.bias = None
return None return None