mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-08 13:20:50 +08:00
fix torch 2.2 compatibility
This commit is contained in:
parent
df271c3e33
commit
7da42cb260
@ -3,6 +3,7 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
from typing import Tuple, Dict
|
from typing import Tuple, Dict
|
||||||
from .float import stochastic_rounding as stochastic_rounding_fn
|
from .float import stochastic_rounding as stochastic_rounding_fn
|
||||||
|
from . import model_management
|
||||||
|
|
||||||
_LAYOUT_REGISTRY = {}
|
_LAYOUT_REGISTRY = {}
|
||||||
_GENERIC_UTILS = {}
|
_GENERIC_UTILS = {}
|
||||||
@ -121,22 +122,45 @@ class QuantizedTensor(torch.Tensor):
|
|||||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
if model_management.torch_version_numeric <= (2, 2):
|
||||||
def __new__(cls, qdata, layout_type, layout_params):
|
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||||
"""
|
|
||||||
Create a quantized tensor.
|
|
||||||
|
|
||||||
Args:
|
@staticmethod
|
||||||
qdata: The quantized data tensor
|
def __new__(cls, qdata, layout_type, layout_params):
|
||||||
layout_type: Layout class (subclass of QuantizedLayout)
|
"""
|
||||||
layout_params: Dict with layout-specific parameters
|
Create a quantized tensor.
|
||||||
"""
|
|
||||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
|
||||||
|
|
||||||
def __init__(self, qdata, layout_type, layout_params):
|
Args:
|
||||||
self._qdata = qdata
|
qdata: The quantized data tensor
|
||||||
self._layout_type = layout_type
|
layout_type: Layout class (subclass of QuantizedLayout)
|
||||||
self._layout_params = layout_params
|
layout_params: Dict with layout-specific parameters
|
||||||
|
"""
|
||||||
|
obj = torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||||
|
obj._qdata = qdata
|
||||||
|
obj._layout_type = layout_type
|
||||||
|
obj._layout_params = layout_params
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def __init__(self, qdata, layout_type, layout_params):
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
@staticmethod
|
||||||
|
def __new__(cls, qdata, layout_type, layout_params):
|
||||||
|
"""
|
||||||
|
Create a quantized tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
qdata: The quantized data tensor
|
||||||
|
layout_type: Layout class (subclass of QuantizedLayout)
|
||||||
|
layout_params: Dict with layout-specific parameters
|
||||||
|
"""
|
||||||
|
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||||
|
|
||||||
|
def __init__(self, qdata, layout_type, layout_params):
|
||||||
|
self._qdata = qdata
|
||||||
|
self._layout_type = layout_type
|
||||||
|
self._layout_params = layout_params
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
layout_name = self._layout_type
|
layout_name = self._layout_type
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user