mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
fix(quant_ops): ensure QuantizedTensor.to(dtype=...) updates orig_dtype to prevent precision mismatch RuntimeErrors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
parent
9907a5e4f5
commit
96803b16c0
@ -60,9 +60,12 @@ except ImportError as e:
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device = kwargs.get("device", None)
|
||||
if device is None and len(args) > 0:
|
||||
dtype = kwargs.get("dtype", None)
|
||||
if len(args) > 0:
|
||||
if isinstance(args[0], (torch.device, str)):
|
||||
device = args[0]
|
||||
elif isinstance(args[0], torch.dtype):
|
||||
dtype = args[0]
|
||||
|
||||
new_qdata = self._qdata.to(*args, **kwargs)
|
||||
new_params = self._layout_params.copy()
|
||||
@ -71,6 +74,9 @@ except ImportError as e:
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params.__dict__[k] = v.to(device=device)
|
||||
|
||||
if dtype is not None:
|
||||
new_params.orig_dtype = dtype
|
||||
|
||||
return type(self)(new_qdata, self._layout_type, new_params)
|
||||
|
||||
def detach(self):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user