diff --git a/comfy/ops.py b/comfy/ops.py index 70ad4c712..0384c8717 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -534,7 +534,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec if dtype != MixedPrecisionOps._compute_dtype: self.comfy_cast_weights = True if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, device=device, dtype=dtype)) + self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) else: self.register_parameter("bias", None) else: @@ -567,7 +567,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec ) if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) + self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) else: self.register_parameter("bias", None)