fix(quant_ops): ensure QuantizedTensor.to(dtype=...) updates orig_dtype to prevent precision mismatch RuntimeErrors

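The fix keeps orig_dtype in the layout params in sync with the dtype
requested through .to(), so code that later consults orig_dtype (for
example when restoring full precision) sees the requested dtype instead
of a stale one. A minimal sketch of the guaranteed behavior; the
QuantizedTensor construction below is illustrative, not the actual
factory API:

    # qdata / layout_type / layout_params are placeholders for an existing
    # quantized tensor whose original dtype was, say, torch.bfloat16
    qt = QuantizedTensor(qdata, layout_type, layout_params)
    converted = qt.to(dtype=torch.float16)      # keyword dtype
    # positional dtype, e.g. qt.to(torch.float16), is handled the same way
    assert converted._layout_params.orig_dtype == torch.float16
    # before this change orig_dtype kept its old value, so results restored
    # to that dtype later raised precision-mismatch RuntimeErrors when
    # combined with float16 tensors
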
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
commit 96803b16c0
parent 9907a5e4f5
Author: Macpaul Lin
Date:   2026-01-08 22:56:03 +08:00

@@ -60,9 +60,12 @@ except ImportError as e:
         def to(self, *args, **kwargs):
             device = kwargs.get("device", None)
-            if device is None and len(args) > 0:
+            dtype = kwargs.get("dtype", None)
+            if len(args) > 0:
                 if isinstance(args[0], (torch.device, str)):
                     device = args[0]
+                elif isinstance(args[0], torch.dtype):
+                    dtype = args[0]
             new_qdata = self._qdata.to(*args, **kwargs)
             new_params = self._layout_params.copy()
@@ -71,6 +74,9 @@ except ImportError as e:
                 if isinstance(v, torch.Tensor):
                     new_params.__dict__[k] = v.to(device=device)
+            if dtype is not None:
+                new_params.orig_dtype = dtype
             return type(self)(new_qdata, self._layout_type, new_params)
 
         def detach(self):