mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
fix(quant_ops): improve comfy_kitchen fallback logic to prevent loading errors
Signed-off-by: Macpaul Lin <macpaul@gmail.com>
This commit is contained in:
parent
ef7b4a717a
commit
406dab2d53
@@ -28,20 +28,85 @@ except ImportError as e:
|
||||
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
|
||||
_CK_AVAILABLE = False
|
||||
|
||||
class ck_dummy:
    """Minimal stand-in for the comfy_kitchen module when the real package
    cannot be imported.  Implements only the single entry point the rest of
    this file relies on."""

    @staticmethod
    def quantize_per_tensor_fp8(tensor, scale, dtype):
        """Divide *tensor* by *scale* (moved onto the tensor's device) and
        cast the result to *dtype* — plain-torch per-tensor fp8 quantization."""
        scaled = tensor / scale.to(tensor.device)
        return scaled.to(dtype)


# Bind the fallback under the same name the successful import would have bound.
ck = ck_dummy
|
||||
class QuantizedTensor:
    """Container pairing a quantized payload tensor with the metadata
    (layout name + layout parameters) needed to dequantize it later."""

    def __init__(self, qdata, layout_type, layout_params):
        # Raw quantized payload and its layout bookkeeping.
        self._qdata = qdata
        self._layout_type = layout_type
        self._layout_params = layout_params
        # Mirror the payload's placement/dtype so callers can inspect this
        # object much like a regular tensor.
        self.device = qdata.device
        self.dtype = qdata.dtype

    @classmethod
    def from_float(cls, tensor, layout_type, **kwargs):
        """Alternate constructor: quantize *tensor* using the layout class
        registered under *layout_type*.  Raises ValueError for unknown
        layout names."""
        layout_cls = get_layout_class(layout_type)
        if layout_cls is None:
            raise ValueError(f"Unknown layout type: {layout_type}")
        qdata, params = layout_cls.quantize(tensor, **kwargs)
        return cls(qdata, layout_type, params)

    def dequantize(self):
        """Reconstruct a dequantized tensor via the layout class; when the
        layout is unknown, fall back to returning the raw payload as-is."""
        layout_cls = get_layout_class(self._layout_type)
        if layout_cls is None:
            return self._qdata
        return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)

    def to(self, *args, **kwargs):
        """Tensor-like ``.to``: move/cast the payload and move any
        tensor-valued layout parameters, returning a new instance."""
        target_device = kwargs.get("device", None)
        # A positional torch.device / device string also selects the device.
        if target_device is None and args and isinstance(args[0], (torch.device, str)):
            target_device = args[0]

        moved_qdata = self._qdata.to(*args, **kwargs)
        moved_params = self._layout_params.copy()
        if target_device is not None:
            for key, value in moved_params.__dict__.items():
                if isinstance(value, torch.Tensor):
                    moved_params.__dict__[key] = value.to(device=target_device)

        return type(self)(moved_qdata, self._layout_type, moved_params)

    def __getattr__(self, name):
        # Expose the layout parameters under the public name ``params``.
        if name == "params":
            return self._layout_params
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # No torch-op interception in this fallback path.
        return NotImplemented
||||
class QuantizedLayout:
    """Base class for quantization layouts.  Provides the small ``Params``
    attribute-bag that layouts use to carry per-tensor metadata."""

    class Params:
        """Plain attribute bag: every constructor keyword becomes an
        instance attribute."""

        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

        def copy(self):
            """Shallow copy by re-feeding the attribute dict to __init__."""
            return type(self)(**self.__dict__)
||||
class _CKFp8Layout(QuantizedLayout):
    # Placeholder fp8 layout for the comfy_kitchen-less fallback path.
    # NOTE(review): the diff shows an older non-subclassing variant too;
    # this reconstructs the new side (subclass of QuantizedLayout) — confirm
    # against the upstream commit.
    pass


class TensorCoreNVFP4Layout(QuantizedLayout):
    # Placeholder NVFP4 (tensor-core) layout for the fallback path;
    # real quantize/dequantize implementations live in comfy_kitchen.
    pass
# Module-local registry mapping layout names to layout classes; populated
# through register_layout_class() when comfy_kitchen is unavailable.
_LOCAL_LAYOUT_REGISTRY = {}


def register_layout_class(name, cls):
    """Record *cls* under *name* so get_layout_class() can resolve it."""
    _LOCAL_LAYOUT_REGISTRY[name] = cls


def get_layout_class(name):
    """Return the layout class registered under *name*, or None when the
    name is unknown."""
    try:
        return _LOCAL_LAYOUT_REGISTRY[name]
    except KeyError:
        return None
||||
def register_layout_op(torch_op, layout_type):
    """No-op stand-in for comfy_kitchen's op-registration decorator.

    Returns a decorator that hands the handler back unchanged, so code
    decorated with it still imports cleanly on the fallback path."""
    def decorator(handler):
        return handler
    return decorator
||||
|
||||
|
||||
import comfy.float
|
||||
import comfy.mps_ops
|
||||
|
||||
Loading…
Reference in New Issue
Block a user