mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
fix(quant_ops): improve comfy_kitchen fallback logic to prevent loading errors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
parent
ef7b4a717a
commit
406dab2d53
@ -28,20 +28,85 @@ except ImportError as e:
|
|||||||
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
|
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
|
||||||
_CK_AVAILABLE = False
|
_CK_AVAILABLE = False
|
||||||
|
|
||||||
|
class ck_dummy:
    """Pure-PyTorch stand-in exposing the one ``ck`` entry point used as a fallback."""

    @staticmethod
    def quantize_per_tensor_fp8(tensor, scale, dtype):
        """Divide ``tensor`` by a per-tensor ``scale`` and cast the result to ``dtype``.

        Args:
            tensor: Input ``torch.Tensor`` to quantize.
            scale: Per-tensor scale, either a ``torch.Tensor`` (moved to the
                device of ``tensor`` first) or a plain Python number.
            dtype: Target ``torch.dtype`` of the returned tensor.

        Returns:
            ``tensor / scale`` converted to ``dtype``. When ``dtype`` is a
            floating-point type with a smaller finite range than the scaled
            values' dtype, the values are saturated to that range before the
            cast.
        """
        if isinstance(scale, torch.Tensor):
            scale = scale.to(tensor.device)
        scaled = tensor / scale

        # A plain cast does not saturate: values beyond the target range turn
        # into inf/NaN in the low-precision dtype. Clamp first, but only when
        # the target is actually narrower, so the bound is representable in
        # the dtype the clamp runs in.
        if dtype.is_floating_point and scaled.dtype.is_floating_point:
            limit = torch.finfo(dtype).max
            if limit < torch.finfo(scaled.dtype).max:
                scaled = scaled.clamp(min=-limit, max=limit)

        return scaled.to(dtype)
|
||||||
|
ck = ck_dummy
|
||||||
|
|
||||||
class QuantizedTensor:
|
class QuantizedTensor:
    """Plain-Python container pairing quantized data with its layout metadata.

    This class is not a ``torch.Tensor`` subclass. It stores the raw quantized
    tensor, the name of the layout it belongs to and the layout parameters
    that are handed back to the layout class when dequantizing.

    Attributes:
        device: Device of the wrapped quantized tensor.
        dtype: Dtype of the wrapped quantized tensor.
        params: Alias for the stored layout parameters (served by
            ``__getattr__``).
    """

    def __init__(self, qdata, layout_type, layout_params):
        """Wrap ``qdata`` together with its layout name and layout parameters."""
        self._qdata = qdata
        self._layout_type = layout_type
        self._layout_params = layout_params
        # Mirror the tensor attributes callers commonly read.
        self.device = qdata.device
        self.dtype = qdata.dtype

    @classmethod
    def from_float(cls, tensor, layout_type, **kwargs):
        """Quantize ``tensor`` with the layout registered under ``layout_type``.

        Extra keyword arguments are forwarded to the layout's ``quantize``.

        Raises:
            ValueError: If no layout class is registered under ``layout_type``.
        """
        layout_cls = get_layout_class(layout_type)
        if layout_cls is None:
            raise ValueError(f"Unknown layout type: {layout_type}")
        qdata, params = layout_cls.quantize(tensor, **kwargs)
        return cls(qdata, layout_type, params)

    def dequantize(self):
        """Return the dequantized tensor produced by the registered layout.

        If no layout class is registered for this tensor's layout type, the
        stored quantized data is returned unchanged.
        """
        layout_cls = get_layout_class(self._layout_type)
        if layout_cls is None:
            return self._qdata
        return layout_cls.dequantize(self._qdata, **vars(self._layout_params))

    def to(self, *args, **kwargs):
        """Return a new instance whose data was converted with ``Tensor.to``.

        All arguments are forwarded unchanged to ``self._qdata.to``. Tensor
        valued layout parameters are then moved to the device the converted
        data ended up on, so data and parameters stay together no matter how
        the target device was spelled (``device=`` keyword, positional device
        or string, or another tensor as in ``to(other_tensor)``).
        """
        new_qdata = self._qdata.to(*args, **kwargs)
        new_params = self._layout_params.copy()

        # Follow the data instead of re-parsing the arguments: the result of
        # ``Tensor.to`` is the single source of truth for the target device.
        target_device = new_qdata.device
        param_dict = vars(new_params)
        for key, value in list(param_dict.items()):
            if isinstance(value, torch.Tensor):
                param_dict[key] = value.to(device=target_device)

        return type(self)(new_qdata, self._layout_type, new_params)

    def __getattr__(self, name):
        """Serve ``params`` as an alias of the layout parameters.

        Only invoked when normal attribute lookup fails; every other missing
        name raises ``AttributeError`` as usual.
        """
        if name == "params":
            return self._layout_params
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Decline to handle any dispatched operation."""
        return NotImplemented
|
||||||
|
|
||||||
|
class QuantizedLayout:
    """Base class that the layout classes in this module derive from."""

    class Params:
        """Attribute bag: each constructor keyword becomes an instance attribute."""

        def __init__(self, **kwargs):
            """Store every keyword argument under an attribute of the same name."""
            for attr_name in kwargs:
                setattr(self, attr_name, kwargs[attr_name])

        def copy(self):
            """Return a new instance of the same class with the same attributes.

            The copy is shallow: attribute values are shared with the original.
            """
            current_attrs = vars(self)
            return self.__class__(**current_attrs)
|
||||||
|
|
||||||
|
class _CKFp8Layout(QuantizedLayout):
    """Empty ``QuantizedLayout`` subclass; it adds no behavior of its own.

    NOTE(review): presumably a placeholder for the fp8 layout that
    comfy_kitchen would otherwise provide -- confirm against the import site.
    """
|
||||||
|
|
||||||
class _CKFp8Layout:
|
class TensorCoreNVFP4Layout(QuantizedLayout):
    """Empty ``QuantizedLayout`` subclass; it adds no behavior of its own.

    NOTE(review): presumably a placeholder for the NVFP4 layout that
    comfy_kitchen would otherwise provide -- confirm against the import site.
    """
|
||||||
|
|
||||||
class TensorCoreNVFP4Layout:
|
_LOCAL_LAYOUT_REGISTRY = {}
|
||||||
pass
|
|
||||||
|
|
||||||
def register_layout_class(name, cls):
|
def register_layout_class(name, cls):
    """Store ``cls`` in the module-level layout registry under ``name``.

    An existing entry with the same name is replaced.
    """
    _LOCAL_LAYOUT_REGISTRY.__setitem__(name, cls)
|
||||||
|
|
||||||
def get_layout_class(name):
|
def get_layout_class(name):
    """Return the layout class registered under ``name``, or ``None`` if absent."""
    try:
        return _LOCAL_LAYOUT_REGISTRY[name]
    except KeyError:
        return None
|
||||||
|
|
||||||
|
def register_layout_op(torch_op, layout_type):
    """Return a decorator that hands the decorated handler back unchanged.

    Both ``torch_op`` and ``layout_type`` are accepted but unused here, so
    decorating a function with this has no registration side effect.
    """

    def _passthrough(handler_func):
        # No registration happens; the handler is returned as-is.
        return handler_func

    return _passthrough
|
||||||
|
|
||||||
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.mps_ops
|
import comfy.mps_ops
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user