mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
fix(quant_ops): implement torch.Tensor.copy_ in __torch_function__ for QuantizedTensor
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
parent
ea3ec049bd
commit
38f5db0118
@ -128,6 +128,12 @@ except ImportError as e:
|
|||||||
dtype = kwargs.get("dtype", input_t.dtype)
|
dtype = kwargs.get("dtype", input_t.dtype)
|
||||||
device = kwargs.get("device", input_t.device)
|
device = kwargs.get("device", input_t.device)
|
||||||
return torch.empty(input_t.shape, dtype=dtype, device=device)
|
return torch.empty(input_t.shape, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
if func is torch.Tensor.copy_:
|
||||||
|
dst, src = args[:2]
|
||||||
|
if isinstance(src, cls):
|
||||||
|
return dst.copy_(src.dequantize(), **kwargs)
|
||||||
|
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user