diff --git a/comfy/ops.py b/comfy/ops.py index 59c0df87d..f47d4137a 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -336,7 +336,10 @@ class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + # don't trust subclasses that BYO state dict loader to call us. + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict): super().__init__(in_features, out_features, bias, device, dtype) return @@ -357,7 +360,9 @@ class disable_weight_init: def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict): return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) disable_weight_init._lazy_load_from_state_dict( @@ -564,7 +569,10 @@ class disable_weight_init: def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + # don't trust subclasses that BYO state dict loader to call us. + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict): super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, _freeze, device, dtype) @@ -590,7 +598,9 @@ class disable_weight_init: def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict): return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) disable_weight_init._lazy_load_from_state_dict(