mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
fix(quant_ops): ensure QuantizedTensor.to(dtype=...) updates orig_dtype to prevent precision mismatch RuntimeErrors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
parent
9907a5e4f5
commit
96803b16c0
@ -60,9 +60,12 @@ except ImportError as e:
|
|||||||
|
|
||||||
def to(self, *args, **kwargs):
|
def to(self, *args, **kwargs):
|
||||||
device = kwargs.get("device", None)
|
device = kwargs.get("device", None)
|
||||||
if device is None and len(args) > 0:
|
dtype = kwargs.get("dtype", None)
|
||||||
|
if len(args) > 0:
|
||||||
if isinstance(args[0], (torch.device, str)):
|
if isinstance(args[0], (torch.device, str)):
|
||||||
device = args[0]
|
device = args[0]
|
||||||
|
elif isinstance(args[0], torch.dtype):
|
||||||
|
dtype = args[0]
|
||||||
|
|
||||||
new_qdata = self._qdata.to(*args, **kwargs)
|
new_qdata = self._qdata.to(*args, **kwargs)
|
||||||
new_params = self._layout_params.copy()
|
new_params = self._layout_params.copy()
|
||||||
@ -71,6 +74,9 @@ except ImportError as e:
|
|||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
new_params.__dict__[k] = v.to(device=device)
|
new_params.__dict__[k] = v.to(device=device)
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
new_params.orig_dtype = dtype
|
||||||
|
|
||||||
return type(self)(new_qdata, self._layout_type, new_params)
|
return type(self)(new_qdata, self._layout_type, new_params)
|
||||||
|
|
||||||
def detach(self):
|
def detach(self):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user