diff --git a/comfy/ops.py b/comfy/ops.py
index 87b36b5c5..077c73b61 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -306,6 +306,53 @@ class CastWeightBiasOp:
     bias_function = []
 
 class disable_weight_init:
+    @staticmethod
+    def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
+                                   missing_keys, unexpected_keys, weight_shape,
+                                   bias_shape=None):
+        """Fill a lazily constructed op's weight/bias from state_dict.
+
+        Shared loader for ops whose __init__ leaves weight (and bias) as None
+        instead of allocating them.
+
+        Args:
+            module: Op being loaded; weight/bias are set on it.
+            state_dict: Key -> tensor mapping; every key is assumed to
+                start with prefix.
+            prefix: Key prefix of module, stripped from each key.
+            local_metadata: Load metadata; a truthy
+                "assign_to_params_buffers" adopts the loaded tensors
+                instead of cloning them.
+            missing_keys: List extended with keys that were zero-filled.
+            unexpected_keys: List extended with keys that were not used.
+            weight_shape: Shape of the zero weight used when none loaded.
+            bias_shape: Shape of the zero bias, or None when the op has
+                no bias (a "bias" key is then reported as unexpected).
+        """
+        assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
+        prefix_len = len(prefix)
+        for k, v in state_dict.items():
+            key = k[prefix_len:]
+            if key == "weight":
+                if not assign_to_params_buffers:
+                    v = v.clone()
+                module.weight = torch.nn.Parameter(v, requires_grad=False)
+            elif bias_shape is not None and key == "bias" and v is not None:
+                if not assign_to_params_buffers:
+                    v = v.clone()
+                module.bias = torch.nn.Parameter(v, requires_grad=False)
+            else:
+                unexpected_keys.append(k)
+
+        # Reconcile default construction of the weight if it is missing.
+        if module.weight is None:
+            module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
+            missing_keys.append(prefix + "weight")
+
+        if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
+            module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
+            missing_keys.append(prefix + "bias")
+
     class Linear(torch.nn.Linear, CastWeightBiasOp):
 
         def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
@@ -333,29 +380,17 @@ class disable_weight_init:
             if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
                 return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                      missing_keys, unexpected_keys, error_msgs)
-            assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
-            prefix_len = len(prefix)
-            for k,v in state_dict.items():
-                if k[prefix_len:] == "weight":
-                    if not assign_to_params_buffers:
-                        v = v.clone()
-                    self.weight = torch.nn.Parameter(v, requires_grad=False)
-                elif k[prefix_len:] == "bias" and v is not None:
-                    if not assign_to_params_buffers:
-                        v = v.clone()
-                    self.bias = torch.nn.Parameter(v, requires_grad=False)
-                else:
-                    unexpected_keys.append(k)
-
-            #Reconcile default construction of the weight if its missing.
-            if self.weight is None:
-                v = torch.zeros(self.in_features, self.out_features)
-                self.weight = torch.nn.Parameter(v, requires_grad=False)
-                missing_keys.append(prefix+"weight")
-            if self.bias is None and self.comfy_need_lazy_init_bias:
-                v = torch.zeros(self.out_features,)
-                self.bias = torch.nn.Parameter(v, requires_grad=False)
-                missing_keys.append(prefix+"bias")
+            disable_weight_init._lazy_load_from_state_dict(
+                self,
+                state_dict,
+                prefix,
+                local_metadata,
+                missing_keys,
+                unexpected_keys,
+                # torch.nn.Linear stores its weight as (out_features, in_features).
+                weight_shape=(self.out_features, self.in_features),
+                bias_shape=(self.out_features,),
+            )
 
 
         def reset_parameters(self):
@@ -547,6 +582,47 @@ class disable_weight_init:
                 return super().forward(*args, **kwargs)
 
     class Embedding(torch.nn.Embedding, CastWeightBiasOp):
+        def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
+                     norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
+                     _freeze=False, device=None, dtype=None):
+            if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+                super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
+                                 norm_type, scale_grad_by_freq, sparse, _weight,
+                                 _freeze, device, dtype)
+                return
+
+            # Lazy path: skip torch.nn.Embedding.__init__ so no weight is
+            # allocated here; _load_from_state_dict fills it in later.
+            # NOTE(review): _weight and _freeze are not used on this path.
+            torch.nn.Module.__init__(self)
+            self.num_embeddings = num_embeddings
+            self.embedding_dim = embedding_dim
+            self.padding_idx = padding_idx
+            self.max_norm = max_norm
+            self.norm_type = norm_type
+            self.scale_grad_by_freq = scale_grad_by_freq
+            self.sparse = sparse
+            self.weight = None
+            self.bias = None
+            self.weight_comfy_model_dtype = dtype
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+
+            if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
+                return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                                     missing_keys, unexpected_keys, error_msgs)
+            disable_weight_init._lazy_load_from_state_dict(
+                self,
+                state_dict,
+                prefix,
+                local_metadata,
+                missing_keys,
+                unexpected_keys,
+                # bias_shape is omitted, so a "bias" key lands in unexpected_keys.
+                weight_shape=(self.num_embeddings, self.embedding_dim),
+            )
+
         def reset_parameters(self):
             self.bias = None
             return None