fix torch 2.2 compatibility

This commit is contained in:
doctorpangloss 2025-12-17 16:33:37 -08:00
parent df271c3e33
commit 7da42cb260

View File

@ -3,6 +3,7 @@ import logging
logger = logging.getLogger(__name__)
from typing import Tuple, Dict
from .float import stochastic_rounding as stochastic_rounding_fn
from . import model_management
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
@ -121,22 +122,45 @@ class QuantizedTensor(torch.Tensor):
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
if model_management.torch_version_numeric <= (2, 2):
__torch_function__ = torch._C._disabled_torch_function_impl
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
obj = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
obj._qdata = qdata
obj._layout_type = layout_type
obj._layout_params = layout_params
return obj
def __init__(self, qdata, layout_type, layout_params):
pass
else:
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type