mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-05 11:40:53 +08:00
fix torch 2.2 compatibility
This commit is contained in:
parent
df271c3e33
commit
7da42cb260
@ -3,6 +3,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import Tuple, Dict
|
||||
from .float import stochastic_rounding as stochastic_rounding_fn
|
||||
from . import model_management
|
||||
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
@ -121,22 +122,45 @@ class QuantizedTensor(torch.Tensor):
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
if model_management.torch_version_numeric <= (2, 2):
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
obj = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
obj._qdata = qdata
|
||||
obj._layout_type = layout_type
|
||||
obj._layout_params = layout_params
|
||||
return obj
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
pass
|
||||
|
||||
else:
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type
|
||||
|
||||
Loading…
Reference in New Issue
Block a user