diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 50f988bfd..993e0f9fc 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -3,6 +3,7 @@ import logging logger = logging.getLogger(__name__) from typing import Tuple, Dict from .float import stochastic_rounding as stochastic_rounding_fn +from . import model_management _LAYOUT_REGISTRY = {} _GENERIC_UTILS = {} @@ -121,22 +122,45 @@ class QuantizedTensor(torch.Tensor): _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ - @staticmethod - def __new__(cls, qdata, layout_type, layout_params): - """ - Create a quantized tensor. + if model_management.torch_version_numeric <= (2, 2): + __torch_function__ = torch._C._disabled_torch_function_impl - Args: - qdata: The quantized data tensor - layout_type: Layout class (subclass of QuantizedLayout) - layout_params: Dict with layout-specific parameters - """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + @staticmethod + def __new__(cls, qdata, layout_type, layout_params): + """ + Create a quantized tensor. - def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata - self._layout_type = layout_type - self._layout_params = layout_params + Args: + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters + """ + obj = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + obj._qdata = qdata + obj._layout_type = layout_type + obj._layout_params = layout_params + return obj + + def __init__(self, qdata, layout_type, layout_params): + pass + + else: + @staticmethod + def __new__(cls, qdata, layout_type, layout_params): + """ + Create a quantized tensor. + + Args: + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters + """ + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata + self._layout_type = layout_type + self._layout_params = layout_params def __repr__(self): layout_name = self._layout_type