import torch
import logging
from typing import Tuple, Dict

import comfy.float

_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}

# Try to import Triton-based INT8 kernels
try:
    from .int8_kernels import (
        act_quant as triton_act_quant,
        act_dequant as triton_act_dequant,
        weight_quant as triton_weight_quant,
        weight_dequant as triton_weight_dequant,
        int8_gemm as triton_int8_gemm,
        int8_addmm as triton_int8_addmm,
        int8_gemm_quant as triton_int8_gemm_quant,
        int8_addmm_quant as triton_int8_addmm_quant,
    )
    _HAS_TRITON_INT8 = True
except ImportError:
    _HAS_TRITON_INT8 = False
    logging.warning("Triton INT8 kernels not available, using PyTorch fallback")

def register_layout_op(torch_op, layout_type):
    """
    Decorator to register a layout-specific operation handler.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
        layout_type: Layout key as registered in LAYOUTS (e.g., "TensorCoreFP8Layout")

    Example:
        @register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
        def fp8_linear(func, args, kwargs):
            # FP8-specific linear implementation
            ...
    """
    def decorator(handler_func):
        if torch_op not in _LAYOUT_REGISTRY:
            _LAYOUT_REGISTRY[torch_op] = {}
        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
        return handler_func
    return decorator

def register_generic_util(torch_op):
    """
    Decorator to register a generic utility that works for all layouts.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)

    Example:
        @register_generic_util(torch.ops.aten.detach.default)
        def generic_detach(func, args, kwargs):
            # Works for any layout
            ...
    """
    def decorator(handler_func):
        _GENERIC_UTILS[torch_op] = handler_func
        return handler_func
    return decorator


def _get_layout_from_args(args):
    for arg in args:
        if isinstance(arg, QuantizedTensor):
            return arg._layout_type
        elif isinstance(arg, (list, tuple)):
            for item in arg:
                if isinstance(item, QuantizedTensor):
                    return item._layout_type
    return None


def _move_layout_params_to_device(params, device):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.to(device=device)
        else:
            new_params[k] = v
    return new_params


def _copy_layout_params(params):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.clone()
        else:
            new_params[k] = v
    return new_params


def _copy_layout_params_inplace(src, dst, non_blocking=False):
    for k, v in src.items():
        if isinstance(v, torch.Tensor):
            dst[k].copy_(v, non_blocking=non_blocking)
        else:
            dst[k] = v

class QuantizedLayout:
    """
    Base class for quantization layouts.

    A layout encapsulates the format-specific logic for quantization/dequantization
    and provides a uniform interface for extracting raw tensors needed for computation.

    New quantization formats should subclass this and implement the required methods.
    """
    @classmethod
    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
        raise NotImplementedError(f"{cls.__name__} must implement quantize()")

    @staticmethod
    def dequantize(qdata, **layout_params) -> torch.Tensor:
        raise NotImplementedError("QuantizedLayout subclasses must implement dequantize()")

    @classmethod
    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")

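# Illustrative sketch of the subclassing contract above (not used by this
# module; the layout name and scale math here are hypothetical):
#
#     class PerTensorINT8Layout(QuantizedLayout):
#         @classmethod
#         def quantize(cls, tensor, **kwargs):
#             scale = tensor.abs().amax().float() / 127.0
#             qdata = torch.clamp(tensor / scale, -127, 127).to(torch.int8)
#             return qdata, {"scale": scale, "orig_dtype": tensor.dtype}
#
#         @staticmethod
#         def dequantize(qdata, scale, orig_dtype, **kwargs):
#             return qdata.to(orig_dtype) * scale
#
#         @classmethod
#         def get_plain_tensors(cls, qtensor):
#             return qtensor._qdata, qtensor._layout_params["scale"]
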
class QuantizedTensor(torch.Tensor):
    """
    Universal quantized tensor that works with any layout.

    This tensor subclass uses a pluggable layout system to support multiple
    quantization formats (FP8, INT4, INT8, etc.) without code duplication.

    The layout_type determines format-specific behavior, while common operations
    (detach, clone, to) are handled generically.

    Attributes:
        _qdata: The quantized tensor data
        _layout_type: Layout key into LAYOUTS (e.g., "TensorCoreFP8Layout")
        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
    """

    @staticmethod
    def __new__(cls, qdata, layout_type, layout_params):
        """
        Create a quantized tensor.

        Args:
            qdata: The quantized data tensor
            layout_type: Layout key into LAYOUTS (e.g., "TensorCoreFP8Layout")
            layout_params: Dict with layout-specific parameters
        """
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata
        self._layout_type = layout_type
        self._layout_params = layout_params
    def __repr__(self):
        layout_name = self._layout_type
        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"

    @property
    def layout_type(self):
        return self._layout_type
    def __tensor_flatten__(self):
        """
        Tensor flattening protocol for proper device movement.
        """
        inner_tensors = ["_qdata"]
        ctx = {
            "layout_type": self._layout_type,
        }

        tensor_params = {}
        non_tensor_params = {}
        for k, v in self._layout_params.items():
            if isinstance(v, torch.Tensor):
                tensor_params[k] = v
            else:
                non_tensor_params[k] = v

        ctx["tensor_param_keys"] = list(tensor_params.keys())
        ctx["non_tensor_params"] = non_tensor_params

        for k, v in tensor_params.items():
            attr_name = f"_layout_param_{k}"
            object.__setattr__(self, attr_name, v)
            inner_tensors.append(attr_name)

        return inner_tensors, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        """
        Tensor unflattening protocol for proper device movement.
        Reconstructs the QuantizedTensor after device movement.
        """
        layout_type = ctx["layout_type"]
        layout_params = dict(ctx["non_tensor_params"])

        for key in ctx["tensor_param_keys"]:
            attr_name = f"_layout_param_{key}"
            layout_params[key] = inner_tensors[attr_name]

        return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
    @classmethod
    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
        qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
        return cls(qdata, layout_type, layout_params)

    def dequantize(self) -> torch.Tensor:
        return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
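    # Illustrative round trip through the two methods above (a sketch, not
    # executed here):
    #
    #     w = torch.randn(256, 256)
    #     qw = QuantizedTensor.from_float(w, "TensorCoreFP8Layout", scale="recalculate")
    #     w_approx = qw.dequantize()   # back to w.dtype, scale folded back in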
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Step 1: Check generic utilities first (detach, clone, to, etc.)
        if func in _GENERIC_UTILS:
            return _GENERIC_UTILS[func](func, args, kwargs)

        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
        layout_type = _get_layout_from_args(args)
        if layout_type and func in _LAYOUT_REGISTRY:
            handler = _LAYOUT_REGISTRY[func].get(layout_type)
            if handler:
                return handler(func, args, kwargs)

        # Step 3: Fallback to dequantization
        # if isinstance(args[0] if args else None, QuantizedTensor):
        #     logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}, args={args}")
        return cls._dequant_and_fallback(func, args, kwargs)
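    # Dispatch flow, illustratively: torch.nn.functional.linear(x_q, w_q, b)
    # arrives here as torch.ops.aten.linear.default. linear is not a generic
    # util, so the layout key of the first QuantizedTensor argument selects the
    # handler registered via register_layout_op below (e.g. fp8_linear or
    # int8_linear); any op with no registered handler runs on dequantized tensors.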
    def data_ptr(self):
        return self._qdata.data_ptr()

    def is_pinned(self):
        return self._qdata.is_pinned()

    def is_contiguous(self):
        return self._qdata.is_contiguous()

    @classmethod
    def _dequant_and_fallback(cls, func, args, kwargs):
        def dequant_arg(arg):
            if isinstance(arg, QuantizedTensor):
                return arg.dequantize()
            elif isinstance(arg, (list, tuple)):
                return type(arg)(dequant_arg(a) for a in arg)
            elif isinstance(arg, dict):
                # Also dequantize QuantizedTensor values passed via kwargs
                return {k: dequant_arg(v) for k, v in arg.items()}
            return arg

        new_args = dequant_arg(args)
        new_kwargs = dequant_arg(kwargs)
        return func(*new_args, **new_kwargs)

# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================

def _create_transformed_qtensor(qt, transform_fn):
    new_data = transform_fn(qt._qdata)
    new_params = _copy_layout_params(qt._layout_params)
    return QuantizedTensor(new_data, qt._layout_type, new_params)

def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
    if target_layout is not None and target_layout != torch.strided:
        logging.warning(
            f"QuantizedTensor: layout change requested to {target_layout}, "
            f"but not supported. Ignoring layout."
        )

    # Handle device transfer
    current_device = qt._qdata.device
    if target_device is not None:
        # Normalize device for comparison
        if isinstance(target_device, str):
            target_device = torch.device(target_device)
        if isinstance(current_device, str):
            current_device = torch.device(current_device)

        if target_device != current_device:
            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
            new_q_data = qt._qdata.to(device=target_device)
            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
            if target_dtype is not None:
                new_params["orig_dtype"] = target_dtype
            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
            return new_qt

    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
    return qt

@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
    """Detach operation - creates a detached copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.detach())
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
    """Clone operation - creates a deep copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.clone())
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
    """Device/dtype transfer operation - handles .to(device) calls."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            op_name="_to_copy"
        )
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
    """Handle .to(device) calls using the dtype_layout variant."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            target_layout=kwargs.get('layout', None),
            op_name="to"
        )
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
    """Handle .to(dtype) calls - dtype conversion only."""
    src = args[0]
    if isinstance(src, QuantizedTensor):
        # For dtype-only conversion, just record the new orig_dtype; no real cast is needed
        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
        src._layout_params["orig_dtype"] = target_dtype
        return src
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
    qt_dest = args[0]
    src = args[1]
    non_blocking = args[2] if len(args) > 2 else False
    if isinstance(qt_dest, QuantizedTensor):
        if isinstance(src, QuantizedTensor):
            # Copy from another quantized tensor
            qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
            qt_dest._layout_type = src._layout_type
            orig_dtype = qt_dest._layout_params["orig_dtype"]
            _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
            qt_dest._layout_params["orig_dtype"] = orig_dtype
        else:
            # Copy from regular tensor - just copy raw data
            qt_dest._qdata.copy_(src)
        return qt_dest
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
    return True

@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
    """Empty_like operation - creates an empty tensor with the same quantized structure."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        # Create empty tensor with same shape and dtype as the quantized data
        hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
        new_qdata = torch.empty_like(qt._qdata, **kwargs)

        # Handle device transfer for layout params
        target_device = kwargs.get('device', new_qdata.device)
        new_params = _move_layout_params_to_device(qt._layout_params, target_device)

        # Update orig_dtype if dtype is specified
        new_params['orig_dtype'] = hp_dtype

        return QuantizedTensor(new_qdata, qt._layout_type, new_params)
    return func(*args, **kwargs)


# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
    """
    Storage format:
        - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
        - scale: Scalar tensor (float32) for dequantization
        - orig_dtype: Original dtype before quantization (for casting back)
    """
    @classmethod
    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
        orig_dtype = tensor.dtype

        if isinstance(scale, str) and scale == "recalculate":
            scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
            if tensor.dtype not in [torch.float32, torch.bfloat16]:  # Prevent scale from being too small
                tensor_info = torch.finfo(tensor.dtype)
                scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))

        if scale is not None:
            if not isinstance(scale, torch.Tensor):
                scale = torch.tensor(scale)
            scale = scale.to(device=tensor.device, dtype=torch.float32)

            if inplace_ops:
                tensor *= (1.0 / scale).to(tensor.dtype)
            else:
                tensor = tensor * (1.0 / scale).to(tensor.dtype)
        else:
            scale = torch.ones((), device=tensor.device, dtype=torch.float32)

        if stochastic_rounding > 0:
            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
        else:
            lp_amax = torch.finfo(dtype).max
            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
            tensor = tensor.to(dtype, memory_format=torch.contiguous_format)

        layout_params = {
            'scale': scale,
            'orig_dtype': orig_dtype
        }
        return tensor, layout_params
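    # Worked example of the scale math above (illustrative): with
    # scale="recalculate" and dtype=torch.float8_e4m3fn (finfo.max = 448.0),
    # a tensor whose absolute maximum is 8.0 gets scale = 8.0 / 448.0;
    # dividing by the scale spreads values over the full FP8 range, and
    # dequantize() multiplies the scale back in.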
    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
        plain_tensor.mul_(scale)
        return plain_tensor

    @classmethod
    def get_plain_tensors(cls, qtensor):
        return qtensor._qdata, qtensor._layout_params['scale']


# ==============================================================================
# Block-Wise INT8 Layout + Operation Handlers
# ==============================================================================
class BlockWiseINT8Layout(QuantizedLayout):
    """
    Block-wise INT8 quantization layout.

    Storage format:
        - qdata: INT8 tensor (torch.int8)
        - scale: Per-block scaling factors (float32)
        - block_size: Size of quantization blocks (default 128)
        - orig_dtype: Original dtype before quantization (for casting back)
        - is_weight: Whether this is a weight tensor (affects blocking dimension)

    Asymmetric blocking:
        - Weights: blocks partition both the first dimension (M) and the second dimension (N);
          scale shape: (M//block_size, N//block_size)
        - Activations: blocks partition along the last dimension (K);
          scale shape: (*batch_dims, K//block_size)
    """
    @classmethod
    def quantize(cls, tensor, scale=None, block_size=128, is_weight=False, **kwargs):
        """
        Quantize a tensor to INT8 with block-wise scaling.

        Args:
            tensor: Input tensor to quantize
            scale: Optional pre-computed scaling factors
            block_size: Size of quantization blocks (default 128)
            is_weight: If True, block along both dimensions (for weights)
                If False, block along last dimension only (for activations)

        Returns:
            Tuple of (quantized_data, layout_params)
        """
        orig_dtype = tensor.dtype

        if not tensor.is_contiguous():
            tensor = tensor.contiguous()

        if is_weight:
            # Weight quantization: block-wise along both M and N dimensions
            # Expected shape: (M, N)
            assert tensor.dim() == 2, f"Weight tensor must be 2D, got shape {tensor.shape}"
            M, N = tensor.shape
            assert M % block_size == 0 and N % block_size == 0, \
                f"Dimensions must be divisible by block_size={block_size}, got shape {tensor.shape}"

            # Use Triton kernel if available AND tensor is on CUDA
            if _HAS_TRITON_INT8 and scale is None and tensor.is_cuda:
                try:
                    qdata, scale = triton_weight_quant(tensor, block_size=block_size)
                except Exception as e:
                    # Don't fall back; re-raise for easier debugging
                    logging.warning(f"Triton weight_quant failed: {e}")
                    raise
                    # qdata, scale = cls._weight_quantize_pytorch(tensor, block_size)
            else:
                qdata, scale = cls._weight_quantize_pytorch(tensor, block_size, scale)

        else:
            # Activation quantization: block-wise along last dimension (K)
            # Can handle any shape: (*batch_dims, K)
            K = tensor.shape[-1]
            assert K % block_size == 0, \
                f"Last dimension must be divisible by block_size={block_size}, got {K}"

            # Use Triton kernel if available AND tensor is on CUDA
            # ignore input scale for now
            # TODO: why do we need input scale?
            if _HAS_TRITON_INT8 and tensor.is_cuda:
                try:
                    qdata, scale = triton_act_quant(tensor, block_size=block_size)
                except Exception as e:
                    logging.warning(f"Triton act_quant failed: {e}, falling back to PyTorch")
                    qdata, scale = cls._activation_quantize_pytorch(tensor, block_size)
            else:
                qdata, scale = cls._activation_quantize_pytorch(tensor, block_size, scale)

        layout_params = {
            'scale': scale.to(torch.float32),
            'block_size': block_size,
            'is_weight': is_weight,
            'orig_dtype': orig_dtype
        }

        return qdata, layout_params
    @staticmethod
    def _weight_quantize_pytorch(tensor, block_size, scale=None):
        """PyTorch fallback for weight quantization"""
        M, N = tensor.shape
        # Reshape to (M//block_size, block_size, N//block_size, block_size)
        tensor_blocked = tensor.reshape(M // block_size, block_size, N // block_size, block_size)
        # Permute to (M//block_size, N//block_size, block_size, block_size)
        tensor_blocked = tensor_blocked.permute(0, 2, 1, 3)

        if scale is None:
            # Compute per-block absolute maximum
            amax = tensor_blocked.abs().amax(dim=(-2, -1))
            scale = amax / 127.0
            scale = torch.maximum(scale, torch.tensor(1e-8, device=scale.device, dtype=scale.dtype))

        # Broadcast scale for division: (M//block_size, N//block_size, 1, 1)
        scale_broadcast = scale.unsqueeze(-1).unsqueeze(-1)
        tensor_scaled = tensor_blocked / scale_broadcast

        # Clamp and convert to int8
        tensor_scaled = torch.clamp(tensor_scaled, -127.0, 127.0)
        qdata = tensor_scaled.to(torch.int8)

        # Reshape back to original shape
        qdata = qdata.permute(0, 2, 1, 3).reshape(M, N)
        return qdata, scale
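    # Shape flow for the method above (illustrative), with a (256, 512) weight
    # and block_size=128:
    #   reshape -> (2, 128, 4, 128)
    #   permute -> (2, 4, 128, 128)   # one (128, 128) tile per scale entry
    #   amax    -> (2, 4)             # per-tile scales
    # and the quantized tiles are permuted/reshaped back to (256, 512).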
    @staticmethod
    def _activation_quantize_pytorch(tensor, block_size, scale=None):
        """PyTorch fallback for activation quantization"""
        K = tensor.shape[-1]
        batch_shape = tensor.shape[:-1]
        tensor_blocked = tensor.reshape(*batch_shape, K // block_size, block_size)

        if scale is None:
            # Compute per-block absolute maximum
            amax = tensor_blocked.abs().amax(dim=-1)
            scale = amax / 127.0
            scale = torch.maximum(scale, torch.tensor(1e-8, device=scale.device, dtype=scale.dtype))

        # Broadcast scale for division
        scale_broadcast = scale.unsqueeze(-1)
        tensor_scaled = tensor_blocked / scale_broadcast

        # Clamp and convert to int8
        tensor_scaled = torch.clamp(tensor_scaled, -127.0, 127.0)
        qdata = tensor_scaled.to(torch.int8)

        # Reshape back to original shape
        qdata = qdata.reshape(tensor.shape)
        return qdata, scale
    @staticmethod
    def dequantize(qdata, scale, block_size, is_weight=False, orig_dtype=None, output_dtype=None, **kwargs):
        """
        Dequantize INT8 tensor back to original precision.

        Args:
            qdata: Quantized INT8 tensor
            scale: Per-block scaling factors
            block_size: Size of quantization blocks
            is_weight: Whether this is a weight tensor
            orig_dtype: Target dtype for dequantization
            output_dtype: Optional override for the output dtype (defaults to orig_dtype)

        Returns:
            Dequantized tensor in orig_dtype (or output_dtype if given)
        """
        if not qdata.is_contiguous():
            qdata = qdata.contiguous()
        if not scale.is_contiguous():
            scale = scale.contiguous()

        if is_weight:
            # Weight dequantization
            if _HAS_TRITON_INT8 and qdata.dim() == 2 and qdata.is_cuda:
                try:
                    dequant = triton_weight_dequant(qdata, scale, block_size=block_size, output_dtype=output_dtype if output_dtype is not None else orig_dtype)
                    return dequant
                except Exception as e:
                    # Don't fall back; re-raise for easier debugging
                    logging.warning(f"Triton weight_dequant failed: {e}")
                    raise

            # PyTorch fallback
            M, N = qdata.shape
            # Ensure scale has the correct shape for weight dequantization
            expected_scale_shape = (M // block_size, N // block_size)
            if scale.shape != expected_scale_shape:
                expected_numel = (M // block_size) * (N // block_size)
                if scale.numel() == expected_numel:
                    scale = scale.reshape(expected_scale_shape)
                else:
                    raise RuntimeError(
                        f"Weight dequant scale shape mismatch: scale.shape={scale.shape}, expected {expected_scale_shape}"
                    )
            qdata_blocked = qdata.reshape(M // block_size, block_size, N // block_size, block_size)
            qdata_blocked = qdata_blocked.permute(0, 2, 1, 3)
            scale_broadcast = scale.unsqueeze(-1).unsqueeze(-1)
            dequant = qdata_blocked.to(orig_dtype) * scale_broadcast
            dequant = dequant.permute(0, 2, 1, 3).reshape(M, N)
        else:
            # Activation dequantization
            if _HAS_TRITON_INT8 and qdata.is_cuda:
                try:
                    dequant = triton_act_dequant(qdata, scale, block_size=block_size, output_dtype=output_dtype if output_dtype is not None else orig_dtype)
                    return dequant
                except Exception as e:
                    # Don't fall back; re-raise for easier debugging
                    logging.warning(f"Triton act_dequant failed: {e}")
                    raise

            # PyTorch fallback
            batch_shape = qdata.shape[:-1]
            K = qdata.shape[-1]
            # Ensure scale has the correct shape for activation dequantization
            expected_scale_shape = (*batch_shape, K // block_size)
            if scale.shape != expected_scale_shape:
                expected_numel = 1
                for dim in expected_scale_shape:
                    expected_numel *= dim
                if scale.numel() == expected_numel:
                    scale = scale.reshape(expected_scale_shape)
                else:
                    raise RuntimeError(
                        f"Activation dequant scale shape mismatch: scale.shape={scale.shape}, expected {expected_scale_shape}"
                    )
            qdata_blocked = qdata.reshape(*batch_shape, K // block_size, block_size)
            scale_broadcast = scale.unsqueeze(-1)
            dequant = qdata_blocked.to(orig_dtype) * scale_broadcast
            dequant = dequant.reshape(qdata.shape)

        return dequant
    @classmethod
    def get_plain_tensors(cls, qtensor):
        """
        Extract raw tensors for computation.

        Returns:
            Tuple of (qdata, scale, block_size, is_weight)
        """
        return (
            qtensor._qdata,
            qtensor._layout_params['scale'],
            qtensor._layout_params['block_size'],
            qtensor._layout_params['is_weight']
        )

QUANT_ALGOS = {
    "float8_e4m3fn": {
        "storage_t": torch.float8_e4m3fn,
        "parameters": {"weight_scale", "input_scale"},
        "comfy_tensor_layout": "TensorCoreFP8Layout",
    },
    "int8_blockwise": {
        "storage_t": torch.int8,
        "parameters": {"weight_scale", "input_scale"},
        "comfy_tensor_layout": "BlockWiseINT8Layout",
        "group_size": 128,  # Default block size
        "asymmetric_layout": True,
    },
}

LAYOUTS = {
    "TensorCoreFP8Layout": TensorCoreFP8Layout,
    "BlockWiseINT8Layout": BlockWiseINT8Layout,
}

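# Illustrative wiring of the two tables above (a sketch, not called here;
# `weight` is a placeholder 2D float tensor):
#
#     algo = QUANT_ALGOS["int8_blockwise"]
#     layout_key = algo["comfy_tensor_layout"]
#     qdata, params = LAYOUTS[layout_key].quantize(
#         weight, block_size=algo["group_size"], is_weight=True)
#     qweight = QuantizedTensor(qdata, layout_key, params)
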
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

        out_dtype = kwargs.get("out_dtype")
        if out_dtype is None:
            out_dtype = input_tensor._layout_params['orig_dtype']

        weight_t = plain_weight.t()

        tensor_2d = False
        if len(plain_input.shape) == 2:
            tensor_2d = True
            plain_input = plain_input.unsqueeze(1)

        input_shape = plain_input.shape
        if len(input_shape) == 3:
            try:
                output = torch._scaled_mm(
                    plain_input.reshape(-1, input_shape[2]).contiguous(),
                    weight_t,
                    bias=bias,
                    scale_a=scale_a,
                    scale_b=scale_b,
                    out_dtype=out_dtype,
                )

                if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
                    output = output[0]

                if not tensor_2d:
                    output = output.reshape((-1, input_shape[1], weight.shape[0]))

                if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                    output_scale = scale_a * scale_b
                    output_params = {
                        'scale': output_scale,
                        'orig_dtype': input_tensor._layout_params['orig_dtype']
                    }
                    return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
                else:
                    return output

            except Exception as e:
                raise RuntimeError(f"FP8 _scaled_mm failed: {e}") from e
        # Inputs with an unexpected rank fall through to the dequantization path below.

    # Case 2: DQ Fallback
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()

    return torch.nn.functional.linear(input_tensor, weight, bias)

def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
    if out_dtype is None:
        out_dtype = input_tensor._layout_params['orig_dtype']

    plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
    plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

    output = torch._scaled_mm(
        plain_input.contiguous(),
        plain_weight,
        bias=bias,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=out_dtype,
    )

    if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
        output = output[0]
    return output

@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
    input_tensor = args[1]
    weight = args[2]
    bias = args[0]

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    if isinstance(args[2], QuantizedTensor):
        a[2] = args[2].dequantize()

    return func(*a, **kwargs)


@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    return func(*a, **kwargs)

@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
    input_tensor = args[0]
    if isinstance(input_tensor, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        ar = list(args)
        ar[0] = plain_input
        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
    return func(*args, **kwargs)


# ==============================================================================
# Block-Wise INT8 Operation Handlers
# ==============================================================================

def _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias=None):
    """
    PyTorch fallback for INT8 matrix multiplication: dequantize and use standard matmul.

    Args:
        a_int8: INT8 activations, shape (*batch, K)
        a_scale: Activation scales, shape (*batch, K//block_size)
        b_int8: INT8 weights, shape (N, K) - standard PyTorch weight format
        b_scale: Weight scales, shape (N//block_size, K//block_size)
        block_size: Block size for quantization
        bias: Optional bias vector, shape (N,)

    Returns:
        Output in float32, shape (*batch, N)
    """
    K = a_int8.shape[-1]
    batch_shape = a_int8.shape[:-1]
    N = b_int8.shape[0]

    # Dequantize activations
    # Ensure a_scale has the correct shape - it should be (*batch_shape, K // block_size)
    expected_scale_shape = (*batch_shape, K // block_size)
    if a_scale.shape != expected_scale_shape:
        # Try to reshape if the number of elements matches
        expected_numel = 1
        for dim in expected_scale_shape:
            expected_numel *= dim
        if a_scale.numel() == expected_numel:
            a_scale = a_scale.reshape(expected_scale_shape)
        else:
            raise RuntimeError(
                f"Scale shape mismatch: a_scale.shape={a_scale.shape}, expected {expected_scale_shape}. " +
                f"a_int8.shape={a_int8.shape}, K={K}, block_size={block_size}"
            )

    a_blocked = a_int8.reshape(*batch_shape, K // block_size, block_size)
    a_scale_broadcast = a_scale.unsqueeze(-1)
    a_fp32 = a_blocked.to(torch.float32) * a_scale_broadcast
    a_fp32 = a_fp32.reshape(*batch_shape, K)

    # Dequantize weights
    # b_int8 is in (N, K) format (standard weight format), b_scale is in (N//block_size, K//block_size) format
    expected_weight_scale_shape = (N // block_size, K // block_size)
    if b_scale.shape != expected_weight_scale_shape:
        # Try to reshape if the number of elements matches
        expected_weight_numel = (N // block_size) * (K // block_size)
        if b_scale.numel() == expected_weight_numel:
            b_scale = b_scale.reshape(expected_weight_scale_shape)
        else:
            raise RuntimeError(
                f"Weight scale shape mismatch: b_scale.shape={b_scale.shape}, expected {expected_weight_scale_shape}. " +
                f"b_int8.shape={b_int8.shape}, N={N}, K={K}, block_size={block_size}"
            )

    # Dequantize weight: (N, K) -> blocks -> dequantize -> (N, K)
    b_blocked = b_int8.reshape(N // block_size, block_size, K // block_size, block_size)
    b_blocked = b_blocked.permute(0, 2, 1, 3)  # (N//bs, K//bs, bs, bs)
    b_scale_broadcast = b_scale.unsqueeze(-1).unsqueeze(-1)
    b_fp32 = b_blocked.to(torch.float32) * b_scale_broadcast
    b_fp32 = b_fp32.permute(0, 2, 1, 3).reshape(N, K)  # Back to (N, K)

    output = torch.nn.functional.linear(a_fp32, b_fp32, bias)
    return output

def _int8_gemm_triton_or_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias=None, out_quant=False):
    """
    INT8 matrix multiplication with optional fused bias using Triton kernels or PyTorch fallback.

    Args:
        a_int8: INT8 activations, shape (*batch, K)
        a_scale: Activation scales, shape (*batch, K//block_size)
        b_int8: INT8 weights, shape (N, K) - standard PyTorch weight format
        b_scale: Weight scales, shape (N//block_size, K//block_size)
        block_size: Block size for quantization
        bias: Optional bias vector, shape (N,)
        out_quant: If True, return quantized output (INT8 + scales) instead of float

    Returns:
        If out_quant=False: Output in float16/float32, shape (*batch, N)
        If out_quant=True: Tuple of (output_int8, output_scale)
    """
    K = a_int8.shape[-1]
    batch_shape = a_int8.shape[:-1]
    # b_int8 is weight in (N, K) format (standard PyTorch weight format)
    N = b_int8.shape[0]
    assert b_int8.shape[1] == K, f"Weight shape mismatch: expected b_int8.shape[1]={K}, got {b_int8.shape[1]}"

    # Try Triton kernel first (only if tensors are on CUDA)
    if _HAS_TRITON_INT8 and a_int8.is_cuda:
        try:
            # int8_gemm/int8_addmm expects: (a, a_s, b, b_s, [bias])
            # a: (*batch, K), a_s: (*batch, K//block_size)
            # b: (N, K), b_s: (N//block_size, K//block_size)
            # Triton kernels transpose b internally

            # Reshape activations to 2D for int8_gemm
            a_2d = a_int8.reshape(-1, K).contiguous()
            a_scale_2d = a_scale.reshape(-1, a_scale.shape[-1]).contiguous()

            # Ensure weight tensors are contiguous
            b_int8_c = b_int8.contiguous()
            b_scale_c = b_scale.contiguous()

            # Call appropriate Triton kernel based on out_quant flag
            if out_quant:
                # Use fused matmul + quantization kernels
                if bias is not None:
                    # Fused addmm + quantization
                    output_2d, output_scale_2d = triton_int8_addmm_quant(
                        a_2d, a_scale_2d, b_int8_c, b_scale_c, bias, out_block_size=block_size
                    )
                else:
                    # Fused gemm + quantization
                    output_2d, output_scale_2d = triton_int8_gemm_quant(
                        a_2d, a_scale_2d, b_int8_c, b_scale_c, out_block_size=block_size
                    )

                # Reshape back to original batch shape
                output = output_2d.reshape(*batch_shape, N)
                output_scale = output_scale_2d.reshape(*batch_shape, N // block_size)
                return output, output_scale
            else:
                # Standard float output
                if bias is not None:
                    # Use fused addmm kernel
                    output_2d = triton_int8_addmm(a_2d, a_scale_2d, b_int8_c, b_scale_c, bias)
                else:
                    # Use standard gemm kernel
                    output_2d = triton_int8_gemm(a_2d, a_scale_2d, b_int8_c, b_scale_c)

                # Reshape back to original batch shape
                output = output_2d.reshape(*batch_shape, N)
                return output
        except Exception as e:
            # Don't fall back; re-raise for easier debugging
            logging.warning(f"Triton int8_gemm/addmm failed: {e}")
            raise

    # Use PyTorch fallback
    fallback_output = _int8_gemm_pytorch_fallback(a_int8, a_scale, b_int8, b_scale, block_size, bias)

    # If out_quant is requested, quantize the fallback output
    if out_quant:
        try:
            # Use Triton activation quantization on the output
            output_int8, output_scale = triton_act_quant(fallback_output, block_size=block_size)
            return output_int8, output_scale
        except Exception:
            # Fall back to pure PyTorch quantization if Triton is not available
            output_int8, output_scale = BlockWiseINT8Layout._activation_quantize_pytorch(
                fallback_output, block_size
            )
            return output_int8, output_scale

    return fallback_output

@register_layout_op(torch.ops.aten.linear.default, "BlockWiseINT8Layout")
def int8_linear(func, args, kwargs):
    """
    Block-wise INT8 linear operation handler with fused Triton kernel support.

    Supports:
        - Both quantized input and weight (uses Triton int8_addmm with fused bias)
        - Mixed precision (quantized weight, float input)
        - Optional quantized output via out_dtype and out_quant parameters
    """
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    # Case 1: Both input and weight are quantized
    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):

        # Extract quantized data
        a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor)
        b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight)

        # Verify configurations
        assert not a_is_weight, "Input tensor should not be marked as weight"
        assert b_is_weight, "Weight tensor should be marked as weight"
        assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}"

        orig_dtype = input_tensor._layout_params['orig_dtype']
        out_dtype = kwargs.get('out_dtype', orig_dtype)
        out_quant = kwargs.get('out_quant', False)  # Whether to return quantized output

        # Weight is already in (N, K) format (standard PyTorch weight format)
        # Pass out_quant to _int8_gemm_triton_or_fallback for fused matmul+quant
        result = _int8_gemm_triton_or_fallback(
            a_int8, a_scale, b_int8, b_scale, a_block_size,
            bias=bias, out_quant=out_quant
        )

        # Handle quantized vs float output
        if out_quant:
            # Result is (output_int8, output_scale) tuple
            output_int8, output_scale = result

            # Wrap in QuantizedTensor
            layout_params = {
                'scale': output_scale,
                'block_size': a_block_size,
                'is_weight': False,
                'orig_dtype': out_dtype
            }
            return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params)
        else:
            # Result is a float tensor
            output = result
            # Convert to target dtype if needed
            if output.dtype != out_dtype:
                output = output.to(out_dtype)
            return output

    # Case 2: Fallback - dequantize and use standard linear
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()

    return torch.nn.functional.linear(input_tensor, weight, bias)

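# Illustrative end-to-end use of the handler above (a sketch; shapes chosen to
# be divisible by the default block_size=128):
#
#     x = torch.randn(4, 512, device="cuda", dtype=torch.float16)
#     w = torch.randn(256, 512, device="cuda", dtype=torch.float16)
#     xq = QuantizedTensor.from_float(x, "BlockWiseINT8Layout", is_weight=False)
#     wq = QuantizedTensor.from_float(w, "BlockWiseINT8Layout", is_weight=True)
#     y = torch.nn.functional.linear(xq, wq)   # dispatched to int8_linear, float16 out
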
@register_layout_op(torch.ops.aten.mm.default, "BlockWiseINT8Layout")
def int8_mm(func, args, kwargs):
    """Block-wise INT8 matrix multiplication handler with Triton kernel support."""
    input_tensor = args[0]
    weight = args[1]

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor)
        b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight)

        assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}"

        # Note: For mm, we expect both to be 2D
        # If input is marked as weight (2D blocking), we need different logic
        # For simplicity, dequantize if configurations don't match the expected pattern
        if a_is_weight or not b_is_weight:
            logging.warning("INT8 mm: Unexpected tensor configurations, falling back to dequantization")
            return func(input_tensor.dequantize(), weight.dequantize())

        orig_dtype = input_tensor._layout_params['orig_dtype']
        out_dtype = kwargs.get('out_dtype', orig_dtype)
        out_quant = kwargs.get('out_quant', False)  # Whether to return quantized output

        # Check if weight needs to be transposed to (N, K) format
        # For mm: input is (M, K), weight should be (N, K) for the kernel
        K = a_int8.shape[-1]
        if b_int8.shape[0] == K and b_int8.shape[1] != K:
            # Weight is in (K, N) format (transposed), transpose back to (N, K)
            b_int8 = b_int8.t().contiguous()
            b_scale = b_scale.t().contiguous()

        result = _int8_gemm_triton_or_fallback(
            a_int8, a_scale, b_int8, b_scale, a_block_size,
            bias=None, out_quant=out_quant
        )

        # Handle quantized vs float output
        if out_quant:
            # Result is (output_int8, output_scale) tuple
            output_int8, output_scale = result

            # Wrap in QuantizedTensor
            layout_params = {
                'scale': output_scale,
                'block_size': a_block_size,
                'is_weight': False,
                'orig_dtype': out_dtype
            }
            return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params)
        else:
            # Result is a float tensor
            output = result
            # Convert to target dtype if needed
            if output.dtype != out_dtype:
                output = output.to(out_dtype)
            return output

    # Fallback
    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    return func(*a, **kwargs)

@register_layout_op(torch.ops.aten.addmm.default, "BlockWiseINT8Layout")
def int8_addmm(func, args, kwargs):
    """
    Block-wise INT8 addmm operation handler with fused Triton kernel support.
    addmm: out = beta * input + alpha * (mat1 @ mat2)

    This uses the fused int8_addmm kernel which combines matmul and bias addition
    in a single pass for better performance.

    Args:
        args[0]: bias tensor
        args[1]: mat1 (input)
        args[2]: mat2 (weight)
    """
    bias = args[0]
    input_tensor = args[1]
    weight = args[2]

    # Case 1: Both input and weight are quantized
    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        # Extract quantized data
        a_int8, a_scale, a_block_size, a_is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor)
        b_int8, b_scale, b_block_size, b_is_weight = BlockWiseINT8Layout.get_plain_tensors(weight)

        # Verify configurations
        assert a_block_size == b_block_size, f"Block sizes must match: {a_block_size} vs {b_block_size}"

        orig_dtype = input_tensor._layout_params['orig_dtype']
        out_dtype = kwargs.get('out_dtype', orig_dtype)
        out_quant = kwargs.get('out_quant', False)  # Whether to return quantized output

        # PyTorch's F.linear internally calls addmm(bias, input, weight.t())
        # So weight arrives in (K, N) format (transposed); transpose back to (N, K)
        # Check if weight is transposed by comparing dimensions with input
        K = a_int8.shape[-1]
        if b_is_weight and b_int8.shape[0] == K:
            # Weight is in (K, N) format (transposed), transpose back to (N, K)
            # The transpose handler also transposed the scale, so transpose it back too
            b_int8 = b_int8.t().contiguous()
            b_scale = b_scale.t().contiguous()

        # Use fused Triton kernel (combines matmul + bias + optional quant)
        result = _int8_gemm_triton_or_fallback(
            a_int8, a_scale, b_int8, b_scale, a_block_size,
            bias=bias, out_quant=out_quant
        )

        # Handle quantized vs float output
        if out_quant:
            # Result is (output_int8, output_scale) tuple
            output_int8, output_scale = result

            # Wrap in QuantizedTensor
            layout_params = {
                'scale': output_scale,
                'block_size': a_block_size,
                'is_weight': False,
                'orig_dtype': out_dtype
            }
            return QuantizedTensor(output_int8, "BlockWiseINT8Layout", layout_params)
        else:
            # Result is a float tensor
            output = result
            # Convert to target dtype if needed
            if output.dtype != out_dtype:
                output = output.to(out_dtype)
            return output

    # Fallback: dequantize and use standard addmm
    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    if isinstance(args[2], QuantizedTensor):
        a[2] = args[2].dequantize()

    return func(*a, **kwargs)

@register_layout_op(torch.ops.aten.view.default, "BlockWiseINT8Layout")
def int8_view(func, args, kwargs):
    """Handle view operations for INT8 tensors."""
    input_tensor = args[0]
    if isinstance(input_tensor, QuantizedTensor):
        # For view, we need to be careful with block structure
        # For safety, we allow these ops but note that they might break block alignment
        plain_input = input_tensor._qdata
        ar = list(args)
        ar[0] = plain_input
        transformed = func(*ar, **kwargs)

        # Return new QuantizedTensor with same layout params
        # Note: This assumes the transformation preserves block structure
        return QuantizedTensor(transformed, "BlockWiseINT8Layout", input_tensor._layout_params)
    return func(*args, **kwargs)


@register_layout_op(torch.ops.aten.t.default, "BlockWiseINT8Layout")
def int8_transpose(func, args, kwargs):
    """Handle transpose operations for INT8 tensors."""
    input_tensor = args[0]
    if isinstance(input_tensor, QuantizedTensor):
        # Transpose the quantized data
        plain_input = input_tensor._qdata
        ar = list(args)
        ar[0] = plain_input
        transformed = func(*ar, **kwargs)

        # For weight tensors, we need to transpose the scale tensor as well
        new_layout_params = input_tensor._layout_params.copy()
        if new_layout_params.get('is_weight', False):
            # Transpose the scale tensor to match the transposed weight
            new_layout_params['scale'] = new_layout_params['scale'].t().contiguous()

        # Return new QuantizedTensor with updated layout params
        return QuantizedTensor(transformed, "BlockWiseINT8Layout", new_layout_params)
    return func(*args, **kwargs)

@register_layout_op(torch.ops.aten.transpose.int, "BlockWiseINT8Layout")
def int8_transpose_int(func, args, kwargs):
    """
    Handle general transpose operations for INT8 tensors.

    torch.transpose(input, dim0, dim1) swaps two dimensions.

    For BlockWiseINT8Layout:
        - Activations: quantized along the last dimension, scale shape is (*batch_dims, K//block_size).
          If the last dimension is swapped, scale handling must be adjusted accordingly.
        - Weights: quantized in 2D blocks (M, N), scale shape is (M//block_size, N//block_size).
          If the two dimensions of a 2D weight are swapped, the scale tensor is transposed too.
    """
    input_tensor = args[0]
    dim0 = args[1] if len(args) > 1 else kwargs.get('dim0', 0)
    dim1 = args[2] if len(args) > 2 else kwargs.get('dim1', 1)

    if isinstance(input_tensor, QuantizedTensor):
        # Transpose the quantized data
        plain_input = input_tensor._qdata
        ar = list(args)
        ar[0] = plain_input
        transformed = func(*ar, **kwargs)

        # Copy layout params
        new_layout_params = input_tensor._layout_params.copy()
        is_weight = new_layout_params.get('is_weight', False)

        # Normalize dimensions to positive indices
        ndim = plain_input.ndim
        if dim0 < 0:
            dim0 = ndim + dim0
        if dim1 < 0:
            dim1 = ndim + dim1

        # Handle scale tensor transposition
        if is_weight:
            # Weight tensors: 2D with block-wise quantization in both dimensions
            if ndim == 2 and set([dim0, dim1]) == {0, 1}:
                # Transposing a 2D weight (M, N) -> (N, M) moves the scale from
                # (M//block_size, N//block_size) -> (N//block_size, M//block_size)
                new_layout_params['scale'] = new_layout_params['scale'].t().contiguous()
            else:
                # Higher-dimensional weights or partial transposes would need more
                # complex scale handling; warn, as this is an uncommon case
                logging.warning(
                    f"Transpose on weight tensor with dims ({dim0}, {dim1}) and shape {plain_input.shape}. "
                    f"Scale tensor may need adjustment for correct behavior."
                )
        else:
            # Activation tensors: block-wise quantization is along the last dimension.
            # Scale has shape (*batch_dims, K//block_size): batch dims map directly to
            # scale dims, and the last scale dim always tracks the quantization blocks.
            last_dim = ndim - 1
            if dim0 == last_dim or dim1 == last_dim:
                # The last dimension is being moved, which affects the quantization
                # blocks. The safest option would be dequantize/transpose/requantize;
                # instead we transpose the corresponding scale dimensions, mapping the
                # input's last dim to the scale's last dim (K//block_size).
                scale = new_layout_params['scale']
                if scale.ndim >= 2:
                    if dim1 == last_dim:
                        # Swapping a batch dim with the last dim
                        scale_dim0 = dim0  # Same batch dimension
                        scale_dim1 = scale.ndim - 1  # Last dim of scale (K//block_size)
                        new_layout_params['scale'] = scale.transpose(scale_dim0, scale_dim1).contiguous()
                    elif dim0 == last_dim:
                        # Swapping the last dim with a batch dim
                        scale_dim0 = scale.ndim - 1  # Last dim of scale
                        scale_dim1 = dim1  # Same batch dimension
                        new_layout_params['scale'] = scale.transpose(scale_dim0, scale_dim1).contiguous()
                    else:
                        # Swapping two batch dimensions (not involving the last dim)
                        new_layout_params['scale'] = scale.transpose(dim0, dim1).contiguous()
                else:
                    logging.warning(
                        f"Transpose involves last dimension but scale tensor has shape {scale.shape}. "
                        f"Scale tensor may need adjustment."
                    )
            else:
                # Transposing batch dimensions that don't affect the quantized dimension:
                # transpose the same dimensions in the scale tensor
                scale = new_layout_params['scale']
                if scale.ndim > max(dim0, dim1):
                    new_layout_params['scale'] = scale.transpose(dim0, dim1).contiguous()

        # Return new QuantizedTensor with updated layout params
        return QuantizedTensor(transformed, "BlockWiseINT8Layout", new_layout_params)

    return func(*args, **kwargs)

@register_layout_op(torch.ops.aten.gelu.default, "BlockWiseINT8Layout")
def int8_gelu(func, args, kwargs):
    """
    Block-wise INT8 GELU activation handler with fused Triton kernel support.

    Supports quantized input -> GELU -> quantized output in a single fused kernel.
    This avoids materializing full-precision intermediate results.
    """
    input_tensor = args[0]

    # Case 1: Input is quantized - use fused kernel
    if isinstance(input_tensor, QuantizedTensor):
        # Extract quantized data
        qdata, scale, block_size, is_weight = BlockWiseINT8Layout.get_plain_tensors(input_tensor)

        orig_dtype = input_tensor._layout_params['orig_dtype']

        # Determine if we should use the Triton kernel
        if _HAS_TRITON_INT8 and qdata.is_cuda:
            try:
                # Import the Triton kernel
                from .int8_kernels import int8_gelu as triton_int8_gelu

                # Call fused kernel
                output_qdata, output_scale = triton_int8_gelu(qdata, scale, block_size=block_size)

                # Wrap result in QuantizedTensor
                layout_params = {
                    'scale': output_scale.to(torch.float32),
                    'block_size': block_size,
                    'is_weight': False,  # Output is always activation format
                    'orig_dtype': orig_dtype
                }
                return QuantizedTensor(output_qdata, "BlockWiseINT8Layout", layout_params)

            except Exception as e:
                logging.warning(f"Triton int8_gelu failed: {e}, falling back to dequantization")
                # Fall through to dequantization fallback

        # Fallback: dequantize, apply GELU, quantize
        fp_input = input_tensor.dequantize()
        fp_output = torch.nn.functional.gelu(fp_input)

        # Quantize output
        output_qdata, output_layout_params = BlockWiseINT8Layout.quantize(
            fp_output,
            block_size=block_size,
            is_weight=False
        )
        output_layout_params['orig_dtype'] = orig_dtype

        return QuantizedTensor(output_qdata, "BlockWiseINT8Layout", output_layout_params)

    # Case 2: Input is not quantized - use standard GELU
    return func(*args, **kwargs)

@register_layout_op(torch.ops.aten.add_.Tensor, "BlockWiseINT8Layout")
def int8_add_(func, args, kwargs):
    """
    Block-wise INT8 in-place addition handler for LoRA application.

    This operation is typically used when applying LoRA to weight matrices.
    Since speed is not critical for this operation:
        - If target is a weight: dequantize, add, then requantize as weight
        - Otherwise: dequantize and fall back to regular addition

    Args:
        args[0]: Target tensor (self) to be modified in-place
        args[1]: Tensor to add
    """
    target = args[0]

    if isinstance(target, QuantizedTensor):
        # Extract quantization parameters
        _, _, block_size, is_weight = BlockWiseINT8Layout.get_plain_tensors(target)

        # Only handle the weight case specially
        if is_weight:
            other = args[1]
            orig_dtype = target._layout_params['orig_dtype']

            # Dequantize target
            target_fp = target.dequantize()

            # Dequantize other if it's also quantized
            if isinstance(other, QuantizedTensor):
                other_fp = other.dequantize()
            else:
                other_fp = other

            # Perform addition
            result_fp = target_fp + other_fp

            # Requantize as weight
            result_qdata, result_layout_params = BlockWiseINT8Layout.quantize(
                result_fp,
                block_size=block_size,
                is_weight=True
            )
            result_layout_params['orig_dtype'] = orig_dtype

            # Update target in-place by copying the new quantized data
            target._qdata.copy_(result_qdata)
            target._layout_params['scale'].copy_(result_layout_params['scale'])
            return target

    # For non-weight tensors or non-quantized tensors, use standard fallback
    return QuantizedTensor._dequant_and_fallback(func, args, kwargs)

@register_layout_op(torch.ops.aten.to.dtype, "BlockWiseINT8Layout")
def int8_to_dtype(func, args, kwargs):
    """
    Block-wise INT8 dtype conversion handler.

    This operation handles .to(dtype) calls on quantized tensors.
        - If converting to torch.int8, do nothing (already in INT8 format)
        - Otherwise, dequantize and fall back

    Args:
        args[0]: Input tensor
        args[1]: Target dtype
    """
    input_tensor = args[0]
    target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype', None)

    if isinstance(input_tensor, QuantizedTensor):
        # If target dtype is int8, the tensor is already in INT8 format
        if target_dtype == torch.int8:
            # No conversion needed, return as-is
            return input_tensor

    # For any other dtype or non-quantized tensors, use standard fallback
    return QuantizedTensor._dequant_and_fallback(func, args, kwargs)