# ComfyUI/comfy/quant_ops.py

import torch
import logging
import dataclasses
from typing import Dict
try:
    import comfy_kitchen as ck
    from comfy_kitchen.tensor import (
        QuantizedTensor as _CKQuantizedTensor,
        QuantizedLayout,
        TensorCoreFP8Layout as _CKFp8Layout,
        TensorCoreNVFP4Layout,  # Direct import, no wrapper needed
        register_layout_op,
    )
    _CK_AVAILABLE = True
    for k, v in ck.list_backends().items():
        logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
    # This module cannot fall back to plain torch ops: every layout below is
    # built on comfy_kitchen, so a missing dependency is a hard error.
    _CK_AVAILABLE = False
    logging.error(f"Failed to import comfy_kitchen: {e}")
    raise ImportError(f"comfy_kitchen is required but not available: {e}") from e

import comfy.float
# ==============================================================================
# Backward Compatibility Layer
# ==============================================================================
class QuantizedTensor(_CKQuantizedTensor):
    @staticmethod
    def __new__(cls, qdata, layout_cls, params):
        # Backward compat: convert string layout names and dict params before
        # delegating to the base __new__.
        if isinstance(layout_cls, str):
            layout_cls = LAYOUTS[layout_cls]
        if isinstance(params, dict):
            params = layout_cls.Params(**params)
        return _CKQuantizedTensor.__new__(cls, qdata, layout_cls, params)

    def __init__(self, qdata, layout_cls, params):
        super().__init__(qdata, layout_cls, params)

    @property
    def _layout_params(self) -> Dict:
        return dataclasses.asdict(self._params)

    @property
    def _layout_type(self) -> str:
        return self._layout_cls.__name__

    @property
    def layout_type(self) -> str:
        """Backward compatibility alias for _layout_type."""
        return self._layout_type

    def _copy_with(self, qdata=None, params=None, clone_params=True):
        if params is None:
            params = self._params.clone() if clone_params else self._params
        return type(self)(
            qdata if qdata is not None else self._qdata,
            self._layout_cls,
            params,
        )
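
# Usage sketch for the compatibility constructor (illustrative comments only,
# not executed). `weight` is a hypothetical 2D float tensor, and the layout
# class is defined further down in this file; the string layout name and dict
# params are the legacy calling convention this subclass exists to accept:
#
#   qdata, params = TensorCoreFP8E4M3Layout.quantize(weight, scale="recalculate")
#   qt = QuantizedTensor(qdata, "TensorCoreFP8E4M3Layout", dataclasses.asdict(params))
#   assert qt.layout_type == "TensorCoreFP8E4M3Layout"
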
# ==============================================================================
# FP8 Layouts with Comfy-Specific Extensions
# ==============================================================================
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
    FP8_DTYPE = None  # Must be overridden in subclasses

    @classmethod
    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
        if cls.FP8_DTYPE is None:
            raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
        orig_dtype = tensor.dtype
        orig_shape = tuple(tensor.shape)
        if isinstance(scale, str) and scale == "recalculate":
            scale = torch.amax(tensor.abs()) / torch.finfo(cls.FP8_DTYPE).max
        if scale is None:
            scale = torch.ones((), device=tensor.device, dtype=torch.float32)
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
        if stochastic_rounding > 0:
            if inplace_ops:
                tensor *= (1.0 / scale).to(tensor.dtype)
            else:
                tensor = tensor * (1.0 / scale).to(tensor.dtype)
            qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
        else:
            qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
        params = cls.Params(scale=scale, orig_dtype=orig_dtype, orig_shape=orig_shape)
        return qdata, params


class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
    FP8_DTYPE = torch.float8_e4m3fn


class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
    FP8_DTYPE = torch.float8_e5m2


# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
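
# Scale-handling sketch (illustrative comments only, not executed; `weight` is
# a hypothetical high-precision tensor). quantize() accepts a missing scale
# (defaults to 1.0), the string "recalculate" (per-tensor absmax scale), or an
# explicit scalar/tensor; a nonzero stochastic_rounding value also serves as
# the rounding seed:
#
#   qdata, params = TensorCoreFP8E4M3Layout.quantize(weight)                       # unit scale
#   qdata, params = TensorCoreFP8E4M3Layout.quantize(weight, scale="recalculate")  # absmax scale
#   qdata, params = TensorCoreFP8E5M2Layout.quantize(weight, scale=0.02,
#                                                    stochastic_rounding=1234)     # seeded SR
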
# ==============================================================================
# Registry
# ==============================================================================
LAYOUTS = {
    "TensorCoreFP8Layout": TensorCoreFP8Layout,  # Backward compat alias (E4M3)
    "TensorCoreFP8E4M3Layout": TensorCoreFP8E4M3Layout,
    "TensorCoreFP8E5M2Layout": TensorCoreFP8E5M2Layout,
    "TensorCoreNVFP4Layout": TensorCoreNVFP4Layout,  # Direct from comfy_kitchen
}

QUANT_ALGOS = {
    "float8_e4m3fn": {
        "storage_t": torch.float8_e4m3fn,
        "parameters": {"weight_scale", "input_scale"},
        "comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
    },
    "float8_e5m2": {
        "storage_t": torch.float8_e5m2,
        "parameters": {"weight_scale", "input_scale"},
        "comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
    },
    "nvfp4": {
        "storage_t": torch.uint8,
        "parameters": {"weight_scale", "weight_scale_2", "input_scale"},
        "comfy_tensor_layout": "TensorCoreNVFP4Layout",
        "group_size": 16,
    },
}
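
# Sketch of how a loader might consume this table (hypothetical variable
# names; the actual checkpoint-loading path lives elsewhere in ComfyUI):
#
#   algo = QUANT_ALGOS["nvfp4"]
#   layout_cls = LAYOUTS[algo["comfy_tensor_layout"]]   # -> TensorCoreNVFP4Layout
#   expected = algo["parameters"]                       # scales the checkpoint must provide
#   storage_dtype = algo["storage_t"]                   # dtype of the packed weight data
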
# ==============================================================================
# Re-exports for backward compatibility
# ==============================================================================
__all__ = [
    "QuantizedTensor",
    "QuantizedLayout",
    "TensorCoreFP8Layout",
    "TensorCoreFP8E4M3Layout",
    "TensorCoreFP8E5M2Layout",
    "TensorCoreNVFP4Layout",
    "LAYOUTS",
    "QUANT_ALGOS",
    "register_layout_op",
]