Mirror of https://github.com/comfyanonymous/ComfyUI.git
bugfix: fix typo in apply_directory for custom_nodes_directory
Allow PATH-style ';'-delimited custom_node directories.
Change delimiter type for separate folders per platform.
feat(API-nodes): move Rodin3D nodes to new client; removed old api client.py (#10645)
Fix qwen controlnet regression. (#10657)
Enable pinned memory by default on Nvidia. (#10656)
Removed the --fast pinned_memory flag.
You can use --disable-pinned-memory to disable it. Please report if it
causes any issues.
Pinned mem also seems to work on AMD. (#10658)
Remove environment variable.
Removed environment variable fallback for custom nodes directory.
Update documentation for custom nodes directory
Clarified the documentation on the custom nodes directory argument; removed the documentation on the environment variable.
Clarify release cycle. (#10667)
Tell users they need to upload their logs in bug reports. (#10671)
mm: guard against double pin and unpin explicitly (#10672)
As commented, if you let CUDA be the one to detect double pinning/unpinning,
it actually creates an async GPU error.
Only unpin tensor if it was pinned by ComfyUI (#10677)
Make ScaleROPE node work on Flux. (#10686)
Add logging for model unloading. (#10692)
Unload weights if vram usage goes up between runs. (#10690)
ops: Put weight cast on the offload stream (#10697)
This needs to be on the offload stream. This reproduced a black screen
with low resolution images on a slow bus when using FP8.
Update CI workflow to remove dead macOS runner. (#10704)
* Update CI workflow to remove dead macOS runner.
* revert
* revert
Don't pin tensor if not a torch.nn.parameter.Parameter (#10718)
Update README.md for Intel Arc GPU installation, remove IPEX (#10729)
IPEX is no longer needed for Intel Arc GPUs. Removing the instructions to set up IPEX.
mm/mp: always unload re-used but modified models (#10724)
The partial unloader path in the model re-use flow skips straight to the
actual unload without any check of the patching UUID. This means that
if you do an upscale flow with a model patch on an existing model, it
will not apply your patches.
Fix by delaying the partial_unload until after the UUID checks. This
is done by making partial_unload a mode of partial_load where extra_mem
is negative.
qwen: reduce VRAM usage (#10725)
Clean up a bunch of stacked and no-longer-needed tensors on the QWEN
VRAM peak (currently FFN).
With this I go from OOMing at B=37x1328x1328 to being able to
successfully run B=47 (RTX5090).
Update Python 3.14 compatibility notes in README (#10730)
Quantized Ops fixes (#10715)
* offload support, bug fixes, remove mixins
* add readme
add PR template for API-Nodes (#10736)
feat: add create_time dict to prompt field in /history and /queue (#10741)
flux: reduce VRAM usage (#10737)
Clean up a bunch of stacked tensors on Flux. This takes me from B=19 to B=22
for 1600x1600 on RTX5090.
Better instructions for the portable. (#10743)
Use same code for chroma and flux blocks so that optimizations are shared. (#10746)
Fix custom nodes import error. (#10747)
This should fix the import errors but will break if the custom nodes actually try to use the class.
revert import reordering
revert imports pt 2
Add left padding support to tokenizers. (#10753)
chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)
Revert "chore(api-nodes): mark OpenAIDalle2 and OpenAIDalle3 nodes as deprecated (#10757)" (#10759)
This reverts commit 9a02382568.
Change ROCm nightly install command to 7.1 (#10764)
546 lines · 19 KiB · Python
import torch
import logging
from typing import Tuple, Dict

_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}


def register_layout_op(torch_op, layout_type):
    """
    Decorator to register a layout-specific operation handler.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
        layout_type: Layout class (e.g., TensorCoreFP8Layout)

    Example:
        @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
        def fp8_linear(func, args, kwargs):
            # FP8-specific linear implementation
            ...
    """
    def decorator(handler_func):
        if torch_op not in _LAYOUT_REGISTRY:
            _LAYOUT_REGISTRY[torch_op] = {}
        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
        return handler_func
    return decorator


def register_generic_util(torch_op):
    """
    Decorator to register a generic utility that works for all layouts.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)

    Example:
        @register_generic_util(torch.ops.aten.detach.default)
        def generic_detach(func, args, kwargs):
            # Works for any layout
            ...
    """
    def decorator(handler_func):
        _GENERIC_UTILS[torch_op] = handler_func
        return handler_func
    return decorator


def _get_layout_from_args(args):
    for arg in args:
        if isinstance(arg, QuantizedTensor):
            return arg._layout_type
        elif isinstance(arg, (list, tuple)):
            for item in arg:
                if isinstance(item, QuantizedTensor):
                    return item._layout_type
    return None


def _move_layout_params_to_device(params, device):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.to(device=device)
        else:
            new_params[k] = v
    return new_params


def _copy_layout_params(params):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.clone()
        else:
            new_params[k] = v
    return new_params


def _copy_layout_params_inplace(src, dst, non_blocking=False):
    for k, v in src.items():
        if isinstance(v, torch.Tensor):
            dst[k].copy_(v, non_blocking=non_blocking)
        else:
            dst[k] = v


class QuantizedLayout:
    """
    Base class for quantization layouts.

    A layout encapsulates the format-specific logic for quantization/dequantization
    and provides a uniform interface for extracting raw tensors needed for computation.

    New quantization formats should subclass this and implement the required methods.
    """
    @classmethod
    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
        raise NotImplementedError(f"{cls.__name__} must implement quantize()")

    @staticmethod
    def dequantize(qdata, **layout_params) -> torch.Tensor:
        raise NotImplementedError("QuantizedLayout subclasses must implement dequantize()")

    @classmethod
    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")


class QuantizedTensor(torch.Tensor):
    """
    Universal quantized tensor that works with any layout.

    This tensor subclass uses a pluggable layout system to support multiple
    quantization formats (FP8, INT4, INT8, etc.) without code duplication.

    The layout_type determines format-specific behavior, while common operations
    (detach, clone, to) are handled generically.

    Attributes:
        _qdata: The quantized tensor data
        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
    """

    @staticmethod
    def __new__(cls, qdata, layout_type, layout_params):
        """
        Create a quantized tensor.

        Args:
            qdata: The quantized data tensor
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata
        self._layout_type = layout_type
        self._layout_params = layout_params

    def __repr__(self):
        layout_name = self._layout_type
        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"

    @property
    def layout_type(self):
        return self._layout_type

    def __tensor_flatten__(self):
        """
        Tensor flattening protocol for proper device movement.
        """
        inner_tensors = ["_qdata"]
        ctx = {
            "layout_type": self._layout_type,
        }

        tensor_params = {}
        non_tensor_params = {}
        for k, v in self._layout_params.items():
            if isinstance(v, torch.Tensor):
                tensor_params[k] = v
            else:
                non_tensor_params[k] = v

        ctx["tensor_param_keys"] = list(tensor_params.keys())
        ctx["non_tensor_params"] = non_tensor_params

        for k, v in tensor_params.items():
            attr_name = f"_layout_param_{k}"
            object.__setattr__(self, attr_name, v)
            inner_tensors.append(attr_name)

        return inner_tensors, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        """
        Tensor unflattening protocol for proper device movement.
        Reconstructs the QuantizedTensor after device movement.
        """
        layout_type = ctx["layout_type"]
        layout_params = dict(ctx["non_tensor_params"])

        for key in ctx["tensor_param_keys"]:
            attr_name = f"_layout_param_{key}"
            layout_params[key] = inner_tensors[attr_name]

        return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)

    @classmethod
    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
        qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
        return cls(qdata, layout_type, layout_params)

    def dequantize(self) -> torch.Tensor:
        return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Step 1: Check generic utilities first (detach, clone, to, etc.)
        if func in _GENERIC_UTILS:
            return _GENERIC_UTILS[func](func, args, kwargs)

        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
        layout_type = _get_layout_from_args(args)
        if layout_type and func in _LAYOUT_REGISTRY:
            handler = _LAYOUT_REGISTRY[func].get(layout_type)
            if handler:
                return handler(func, args, kwargs)

        # Step 3: Fallback to dequantization
        if isinstance(args[0] if args else None, QuantizedTensor):
            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
        return cls._dequant_and_fallback(func, args, kwargs)

    @classmethod
    def _dequant_and_fallback(cls, func, args, kwargs):
        def dequant_arg(arg):
            if isinstance(arg, QuantizedTensor):
                return arg.dequantize()
            elif isinstance(arg, (list, tuple)):
                return type(arg)(dequant_arg(a) for a in arg)
            return arg

        new_args = dequant_arg(args)
        new_kwargs = dequant_arg(kwargs)
        return func(*new_args, **new_kwargs)
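
# Illustrative note (not from the original file): for an FP8-quantized weight
# w_q, a call such as torch.nn.functional.linear(x, w_q) reaches
# __torch_dispatch__ above and resolves in this order:
#   1. aten.linear.default is not in _GENERIC_UTILS, so step 1 is skipped.
#   2. _get_layout_from_args finds "TensorCoreFP8Layout" on w_q, so step 2
#      dispatches to the fp8_linear handler registered further down.
#   3. Ops with no registered handler fall through to _dequant_and_fallback,
#      which dequantizes the arguments and re-runs the op on plain tensors.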

# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================

def _create_transformed_qtensor(qt, transform_fn):
    new_data = transform_fn(qt._qdata)
    new_params = _copy_layout_params(qt._layout_params)
    return QuantizedTensor(new_data, qt._layout_type, new_params)


def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
    if target_dtype is not None and target_dtype != qt.dtype:
        logging.warning(
            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
            f"but not supported for quantized tensors. Ignoring dtype."
        )

    if target_layout is not None and target_layout != torch.strided:
        logging.warning(
            f"QuantizedTensor: layout change requested to {target_layout}, "
            f"but not supported. Ignoring layout."
        )

    # Handle device transfer
    current_device = qt._qdata.device
    if target_device is not None:
        # Normalize device for comparison
        if isinstance(target_device, str):
            target_device = torch.device(target_device)
        if isinstance(current_device, str):
            current_device = torch.device(current_device)

        if target_device != current_device:
            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
            new_q_data = qt._qdata.to(device=target_device)
            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
            return new_qt

    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
    return qt


@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
    """Detach operation - creates a detached copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.detach())
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
    """Clone operation - creates a deep copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.clone())
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
    """Device/dtype transfer operation - handles .to(device) calls."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            op_name="_to_copy"
        )
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
    """Handle .to(device) calls using the dtype_layout variant."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            target_layout=kwargs.get('layout', None),
            op_name="to"
        )
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
    qt_dest = args[0]
    src = args[1]
    non_blocking = args[2] if len(args) > 2 else False
    if isinstance(qt_dest, QuantizedTensor):
        if isinstance(src, QuantizedTensor):
            # Copy from another quantized tensor
            qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
            qt_dest._layout_type = src._layout_type
            _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
        else:
            # Copy from regular tensor - just copy raw data
            qt_dest._qdata.copy_(src)
        return qt_dest
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
    return True


@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
    """Empty_like operation - creates an empty tensor with the same quantized structure."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        # Create empty tensor with same shape and dtype as the quantized data
        hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
        new_qdata = torch.empty_like(qt._qdata, **kwargs)

        # Handle device transfer for layout params
        target_device = kwargs.get('device', new_qdata.device)
        new_params = _move_layout_params_to_device(qt._layout_params, target_device)

        # Update orig_dtype if dtype is specified
        new_params['orig_dtype'] = hp_dtype

        return QuantizedTensor(new_qdata, qt._layout_type, new_params)
    return func(*args, **kwargs)
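
# Illustrative sketch (not part of the original file): the generic utilities
# above are what let standard tensor plumbing work on QuantizedTensor without
# dequantizing. Assuming q and q_dst are QuantizedTensor instances:
#
#   q.detach()      # -> aten.detach.default                    -> generic_detach
#   q.clone()       # -> aten.clone.default                     -> generic_clone
#   q.to("cuda")    # -> aten._to_copy / aten.to.dtype_layout   -> _handle_device_transfer
#   q_dst.copy_(q)  # -> aten.copy_.default                     -> generic_copy_
#
# The names q and q_dst are placeholders used only in this sketch.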

# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
    """
    Storage format:
    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
    - scale: Scalar tensor (float32) for dequantization
    - orig_dtype: Original dtype before quantization (for casting back)
    """
    @classmethod
    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
        orig_dtype = tensor.dtype

        if scale is None:
            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max

        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)

        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
        # TODO: uncomment this if it's actually needed because the clamp has a small performance penalty
        # lp_amax = torch.finfo(dtype).max
        # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

        layout_params = {
            'scale': scale,
            'orig_dtype': orig_dtype
        }
        return qdata, layout_params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
        return plain_tensor * scale

    @classmethod
    def get_plain_tensors(cls, qtensor):
        return qtensor._qdata, qtensor._layout_params['scale']
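
# Minimal round-trip sketch for TensorCoreFP8Layout (illustrative comment only;
# assumes a PyTorch build with float8_e4m3fn support):
#
#   w = torch.randn(128, 64, dtype=torch.float16)
#   qdata, params = TensorCoreFP8Layout.quantize(w)             # FP8 data + {'scale', 'orig_dtype'}
#   w_approx = TensorCoreFP8Layout.dequantize(qdata, **params)  # back to float16, up to FP8 rounding error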

QUANT_ALGOS = {
    "float8_e4m3fn": {
        "storage_t": torch.float8_e4m3fn,
        "parameters": {"weight_scale", "input_scale"},
        "comfy_tensor_layout": "TensorCoreFP8Layout",
    },
}

LAYOUTS = {
    "TensorCoreFP8Layout": TensorCoreFP8Layout,
}


@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

        out_dtype = kwargs.get("out_dtype")
        if out_dtype is None:
            out_dtype = input_tensor._layout_params['orig_dtype']

        weight_t = plain_weight.t()

        tensor_2d = False
        if len(plain_input.shape) == 2:
            tensor_2d = True
            plain_input = plain_input.unsqueeze(1)

        input_shape = plain_input.shape
        if len(input_shape) != 3:
            return None

        try:
            output = torch._scaled_mm(
                plain_input.reshape(-1, input_shape[2]).contiguous(),
                weight_t,
                bias=bias,
                scale_a=scale_a,
                scale_b=scale_b,
                out_dtype=out_dtype,
            )

            if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
                output = output[0]

            if not tensor_2d:
                output = output.reshape((-1, input_shape[1], weight.shape[0]))

            if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                output_scale = scale_a * scale_b
                output_params = {
                    'scale': output_scale,
                    'orig_dtype': input_tensor._layout_params['orig_dtype']
                }
                return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
            else:
                return output

        except Exception as e:
            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")

    # Case 2: DQ Fallback
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()

    return torch.nn.functional.linear(input_tensor, weight, bias)
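
# Shape handling in fp8_linear above (descriptive note): a 2D input [M, K] is
# unsqueezed to [M, 1, K]; the (now 3D) input [B, S, K] is flattened to
# [B*S, K] for torch._scaled_mm against the transposed weight [K, N]; for the
# originally-3D case the result is reshaped back to [B, S, N] afterwards.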


def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
    if out_dtype is None:
        out_dtype = input_tensor._layout_params['orig_dtype']

    plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
    plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

    output = torch._scaled_mm(
        plain_input.contiguous(),
        plain_weight,
        bias=bias,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=out_dtype,
    )

    if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
        output = output[0]
    return output


@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
    input_tensor = args[1]
    weight = args[2]
    bias = args[0]

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    if isinstance(args[2], QuantizedTensor):
        a[2] = args[2].dequantize()

    return func(*a, **kwargs)


@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    return func(*a, **kwargs)


@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
    input_tensor = args[0]
    if isinstance(input_tensor, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        ar = list(args)
        ar[0] = plain_input
        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
    return func(*args, **kwargs)
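

# The guarded block below is not part of the original module; it is a minimal
# smoke-test sketch showing QuantizedTensor end to end: quantize with
# from_float, exercise the generic utilities, and dequantize. It assumes a
# PyTorch build with float8_e4m3fn support and stays on CPU, so the FP8
# torch._scaled_mm path (which needs suitable hardware) is not exercised.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    w = torch.randn(128, 64, dtype=torch.float32)
    w_q = QuantizedTensor.from_float(w, "TensorCoreFP8Layout")
    print(repr(w_q))

    # Generic utilities: detach keeps the tensor quantized.
    w_q_detached = w_q.detach()
    print(type(w_q_detached).__name__, w_q_detached.layout_type)

    # Dequantize and report the reconstruction error introduced by FP8 rounding.
    w_back = w_q.dequantize()
    print("max abs error:", (w - w_back).abs().max().item())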