fix(quant_ops): implement torch.Tensor.copy_ in __torch_function__ for QuantizedTensor

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
Macpaul Lin 2026-01-08 22:59:48 +08:00
parent ea3ec049bd
commit 38f5db0118

View File

@ -128,6 +128,12 @@ except ImportError as e:
dtype = kwargs.get("dtype", input_t.dtype)
device = kwargs.get("device", input_t.device)
return torch.empty(input_t.shape, dtype=dtype, device=device)
if func is torch.Tensor.copy_:
dst, src = args[:2]
if isinstance(src, cls):
return dst.copy_(src.dequantize(), **kwargs)
return NotImplemented
def __torch_dispatch__(self, func, types, args=(), kwargs=None):