Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-23 21:00:16 +08:00
Respect the dtype the op was initialized in for non quant mixed op.
This commit is contained in:
parent 982876d59a
commit e169c4567c

comfy/ops.py (11 changed lines)
@@ -497,8 +497,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             ) -> None:
                 super().__init__()
 
-                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
-                # self.factory_kwargs = {"device": device, "dtype": dtype}
+                if dtype is None:
+                    dtype = MixedPrecisionOps._compute_dtype
+
+                self.factory_kwargs = {"device": device, "dtype": dtype}
 
                 self.in_features = in_features
                 self.out_features = out_features
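The first hunk changes how the constructor picks the parameter dtype: previously factory_kwargs was always built with MixedPrecisionOps._compute_dtype, so an explicitly requested dtype was ignored; now an explicit dtype wins and only dtype=None falls back to the compute dtype. A minimal standalone sketch of that fallback pattern (SketchLinear and _COMPUTE_DTYPE are illustrative stand-ins, not ComfyUI code):

import torch

_COMPUTE_DTYPE = torch.bfloat16  # stand-in for MixedPrecisionOps._compute_dtype

class SketchLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        if dtype is None:                # only the unset case falls back
            dtype = _COMPUTE_DTYPE
        self.factory_kwargs = {"device": device, "dtype": dtype}

print(SketchLinear(4, 4).factory_kwargs["dtype"])                       # torch.bfloat16
print(SketchLinear(4, 4, dtype=torch.float16).factory_kwargs["dtype"])  # torch.float16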
@@ -530,7 +532,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                     layer_conf = json.loads(layer_conf.numpy().tobytes())
 
                 if layer_conf is None:
-                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+                    dtype = self.factory_kwargs["dtype"]
+                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
+                    if dtype != MixedPrecisionOps._compute_dtype:
+                        self.comfy_cast_weights = True
                 else:
                     self.quant_format = layer_conf.get("format", None)
                     if not self._full_precision_mm:
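The second hunk makes the non-quantized load path honor that stored dtype: the weight is materialized in factory_kwargs["dtype"] rather than unconditionally in the compute dtype, and comfy_cast_weights is set when the two differ so the weight is cast at run time instead of up front. A standalone sketch of that cast-on-use behavior (StoredWeightLinear, compute_dtype, and cast_on_forward are illustrative names, not ComfyUI's actual cast machinery):

import torch

class StoredWeightLinear(torch.nn.Module):
    def __init__(self, weight, compute_dtype=torch.bfloat16):
        super().__init__()
        # keep the weight in the dtype it was initialized/loaded in
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.compute_dtype = compute_dtype
        # remember whether a runtime cast is needed (analogue of comfy_cast_weights)
        self.cast_on_forward = weight.dtype != compute_dtype

    def forward(self, x):
        w = self.weight
        if self.cast_on_forward:
            w = w.to(self.compute_dtype)  # cast for the matmul only; storage stays put
        return torch.nn.functional.linear(x.to(self.compute_dtype), w)

layer = StoredWeightLinear(torch.randn(8, 8, dtype=torch.float32))
y = layer(torch.randn(2, 8, dtype=torch.bfloat16))
print(layer.weight.dtype, y.dtype)  # torch.float32 torch.bfloat16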