diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 08a8a996d..545dffb30 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -28,20 +28,85 @@ except ImportError as e:
     logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
     _CK_AVAILABLE = False
 
+    class ck_dummy:
+        @staticmethod
+        def quantize_per_tensor_fp8(tensor, scale, dtype):
+            return (tensor / scale.to(tensor.device)).to(dtype)
+    ck = ck_dummy
+
     class QuantizedTensor:
+        def __init__(self, qdata, layout_type, layout_params):
+            self._qdata = qdata
+            self._layout_type = layout_type
+            self._layout_params = layout_params
+            self.device = qdata.device
+            self.dtype = qdata.dtype
+
+        @classmethod
+        def from_float(cls, tensor, layout_type, **kwargs):
+            layout_cls = get_layout_class(layout_type)
+            if layout_cls is None:
+                raise ValueError(f"Unknown layout type: {layout_type}")
+            qdata, params = layout_cls.quantize(tensor, **kwargs)
+            return cls(qdata, layout_type, params)
+
+        def dequantize(self):
+            layout_cls = get_layout_class(self._layout_type)
+            if layout_cls is None:
+                return self._qdata
+            return layout_cls.dequantize(self._qdata, **self._layout_params.__dict__)
+
+        def to(self, *args, **kwargs):
+            device = kwargs.get("device", None)
+            if device is None and len(args) > 0:
+                if isinstance(args[0], (torch.device, str)):
+                    device = args[0]
+
+            new_qdata = self._qdata.to(*args, **kwargs)
+            new_params = self._layout_params.copy()
+            if device is not None:
+                for k, v in new_params.__dict__.items():
+                    if isinstance(v, torch.Tensor):
+                        new_params.__dict__[k] = v.to(device=device)
+
+            return type(self)(new_qdata, self._layout_type, new_params)
+
+        def __getattr__(self, name):
+            if name == "params":
+                return self._layout_params
+            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+            return NotImplemented
+
+    class QuantizedLayout:
+        class Params:
+            def __init__(self, **kwargs):
+                for k, v in kwargs.items():
+                    setattr(self, k, v)
+
+            def copy(self):
+                return type(self)(**self.__dict__)
+
+    class _CKFp8Layout(QuantizedLayout):
         pass
 
-    class _CKFp8Layout:
+    class TensorCoreNVFP4Layout(QuantizedLayout):
         pass
 
-    class TensorCoreNVFP4Layout:
-        pass
+    _LOCAL_LAYOUT_REGISTRY = {}
 
     def register_layout_class(name, cls):
-        pass
+        _LOCAL_LAYOUT_REGISTRY[name] = cls
 
     def get_layout_class(name):
-        return None
+        return _LOCAL_LAYOUT_REGISTRY.get(name)
+
+    def register_layout_op(torch_op, layout_type):
+        def decorator(handler_func):
+            return handler_func
+        return decorator
+
 
 import comfy.float
 import comfy.mps_ops