From 96803b16c0b3d8a9e781015e48c8e682d0679cba Mon Sep 17 00:00:00 2001 From: Macpaul Lin Date: Thu, 8 Jan 2026 22:56:03 +0800 Subject: [PATCH] fix(quant_ops): ensure QuantizedTensor.to(dtype=...) updates orig_dtype to prevent precision mismatch RuntimeErrors Signed-off-by: Macpaul Lin --- comfy/quant_ops.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index e80e6bcdc..8f282ee5b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -60,9 +60,12 @@ except ImportError as e: def to(self, *args, **kwargs): device = kwargs.get("device", None) - if device is None and len(args) > 0: + dtype = kwargs.get("dtype", None) + if len(args) > 0: if isinstance(args[0], (torch.device, str)): device = args[0] + elif isinstance(args[0], torch.dtype): + dtype = args[0] new_qdata = self._qdata.to(*args, **kwargs) new_params = self._layout_params.copy() @@ -71,6 +74,9 @@ except ImportError as e: if isinstance(v, torch.Tensor): new_params.__dict__[k] = v.to(device=device) + if dtype is not None: + new_params.orig_dtype = dtype + return type(self)(new_qdata, self._layout_type, new_params) def detach(self):