Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-03-06 09:47:35 +08:00)
Update for torch compile comfy kitchen.
parent b9c3ad1c93
commit 3c7b599222
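
In short: the commit replaces quant_ops' module-level LAYOUTS dict with comfy_kitchen's register_layout_class/get_layout_class registry and threads layout names (plain strings) instead of class objects through QuantizedTensor, presumably so torch.compile sees stable string values rather than Python class objects. A minimal sketch of the registry pattern (the real functions live in comfy_kitchen.tensor; everything else here is a stand-in):

    # Hedged sketch of the registry pattern this commit adopts; the real
    # register_layout_class/get_layout_class come from comfy_kitchen.tensor.
    _LAYOUT_REGISTRY = {}

    def register_layout_class(name, cls):
        # Map a string name to a layout class; callers carry the plain
        # string around and resolve the class only where it is needed.
        _LAYOUT_REGISTRY[name] = cls

    def get_layout_class(name):
        return _LAYOUT_REGISTRY[name]

    class ExampleLayout:  # stand-in for a TensorCore*Layout class
        pass

    register_layout_class("ExampleLayout", ExampleLayout)
    assert get_layout_class("ExampleLayout") is ExampleLayout
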
comfy/ops.py (22 lines changed)

--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if input is not None:
         if dtype is None:
             if isinstance(input, QuantizedTensor):
-                dtype = input._layout_params["orig_dtype"]
+                dtype = input.params.orig_dtype
             else:
                 dtype = input.dtype
         if bias_dtype is None:
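
The one functional change in cast_bias_weight: the old code fetched orig_dtype through _layout_params, a property that rebuilt a dict from the params dataclass on every call (see the removed compatibility layer further down), while the new code reads the field directly. A rough sketch of the difference, with a stand-in params type rather than the real comfy_kitchen one:

    import dataclasses
    import torch

    @dataclasses.dataclass
    class _ExampleParams:  # stand-in for a layout's Params dataclass
        orig_dtype: torch.dtype
        scale: float

    p = _ExampleParams(orig_dtype=torch.bfloat16, scale=1.0)
    old_style = dataclasses.asdict(p)["orig_dtype"]  # built a dict per access
    new_style = p.orig_dtype                         # plain attribute read
    assert old_style == new_style
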
@@ -488,11 +488,8 @@ if CUBLAS_IS_AVAILABLE:
 from .quant_ops import (
     QuantizedTensor,
     QUANT_ALGOS,
-    LAYOUTS,
     TensorCoreFP8Layout,
-    TensorCoreFP8E4M3Layout,
-    TensorCoreFP8E5M2Layout,
-    TensorCoreNVFP4Layout
+    get_layout_class,
 )
 
 
@@ -567,7 +564,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
         qconfig = QUANT_ALGOS[self.quant_format]
         self.layout_type = qconfig["comfy_tensor_layout"]
-        layout_cls = LAYOUTS[self.layout_type]
+        layout_cls = get_layout_class(self.layout_type)
 
         # Load format-specific parameters
         if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
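
This hunk and the ones that follow share a pattern: self.layout_type keeps the layout name from the quant config, and the class is resolved through the registry only where it is needed. Roughly, using the key names from the diff and an illustrative config entry:

    # Illustrative lookup chain; QUANT_ALGOS in the real code carries more
    # fields (e.g. "storage_t"), this sketch keeps only the layout entry.
    EXAMPLE_QUANT_ALGOS = {
        "float8_e4m3fn": {"comfy_tensor_layout": "TensorCoreFP8E4M3Layout"},
    }

    def resolve_layout(quant_format, get_layout_class):
        qconfig = EXAMPLE_QUANT_ALGOS[quant_format]
        layout_type = qconfig["comfy_tensor_layout"]  # kept as a string
        return get_layout_class(layout_type)          # class resolved late
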
@@ -599,7 +596,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             raise ValueError(f"Unsupported quantization format: {self.quant_format}")
 
         self.weight = torch.nn.Parameter(
-            QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), layout_cls, params),
+            QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
             requires_grad=False
         )
 
@@ -626,10 +623,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         layout_cls = self.weight._layout_cls
 
         # Check if it's any FP8 variant (E4M3 or E5M2)
-        if layout_cls in (TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout) or \
-           layout_cls.__name__ in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
+        if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
             sd["{}weight_scale".format(prefix)] = self.weight._params.scale
-        elif layout_cls == TensorCoreNVFP4Layout or layout_cls.__name__ == "TensorCoreNVFP4Layout":
+        elif layout_cls == "TensorCoreNVFP4Layout":
             sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
             sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
 
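
With layout identity reduced to a string, the state-dict export branches on names alone, dropping the old class-object-or-__name__ double check. Condensed sketch (the sd/prefix keys match the diff; the params object is illustrative):

    FP8_LAYOUT_NAMES = ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout")

    def export_scales(sd, prefix, layout_cls, params):
        # layout_cls is now a string name, so plain membership tests suffice
        if layout_cls in FP8_LAYOUT_NAMES:
            sd["{}weight_scale".format(prefix)] = params.scale
        elif layout_cls == "TensorCoreNVFP4Layout":
            sd["{}weight_scale_2".format(prefix)] = params.scale
            sd["{}weight_scale".format(prefix)] = params.block_scale
        return sd
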
@@ -659,7 +655,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
         if (getattr(self, 'layout_type', None) is not None and
                 not isinstance(input, QuantizedTensor)):
-            layout_cls = LAYOUTS[self.layout_type]
 
             # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
             if tensor_3d:
@@ -670,7 +665,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
 
             # dtype is now implicit in the layout class
-            input = QuantizedTensor.from_float(input, layout_cls, scale=getattr(self, 'input_scale', None))
+            input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
 
         output = self._forward(input, self.weight, self.bias)
 
@@ -688,9 +683,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
 
     def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
         if getattr(self, 'layout_type', None) is not None:
-            layout_cls = LAYOUTS[self.layout_type]
             # dtype is now implicit in the layout class
-            weight = QuantizedTensor.from_float(weight, layout_cls, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
+            weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True)
         else:
             weight = weight.to(self.weight.dtype)
         if return_weight:

--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -6,65 +6,39 @@ from typing import Dict
 try:
     import comfy_kitchen as ck
     from comfy_kitchen.tensor import (
-        QuantizedTensor as _CKQuantizedTensor,
+        QuantizedTensor,
         QuantizedLayout,
         TensorCoreFP8Layout as _CKFp8Layout,
         TensorCoreNVFP4Layout,  # Direct import, no wrapper needed
         register_layout_op,
+        register_layout_class,
+        get_layout_class,
     )
     _CK_AVAILABLE = True
+    ck.registry.disable("triton")
     for k, v in ck.list_backends().items():
         logging.info(f"Found comfy_kitchen backend {k}: {v}")
 except ImportError as e:
-    logging.info(f"Failed to import comfy_kitchen, falling back to torch ops. Error: {e}")
+    logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
     _CK_AVAILABLE = False
-    raise ImportError(f"comfy_kitchen is required but not available: {e}")
+
+    class QuantizedTensor:
+        pass
+
+    class _CKFp8Layout:
+        pass
+
+    class TensorCoreNVFP4Layout:
+        pass
+
+    def register_layout_class(name, cls):
+        pass
+
+    def get_layout_class(name):
+        return None
 
 import comfy.float
 
-
-# ==============================================================================
-# Backward Compatibility Layer
-# ==============================================================================
-
-class QuantizedTensor(_CKQuantizedTensor):
-    @staticmethod
-    def __new__(cls, qdata, layout_cls, params):
-        # Backward compat: Convert string layout names and dict params before __new__
-        if isinstance(layout_cls, str):
-            layout_cls = LAYOUTS[layout_cls]
-
-        if isinstance(params, dict):
-            params = layout_cls.Params(**params)
-
-        return _CKQuantizedTensor.__new__(cls, qdata, layout_cls, params)
-
-    def __init__(self, qdata, layout_cls, params):
-        super().__init__(qdata, layout_cls, params)
-
-    @property
-    def _layout_params(self) -> Dict:
-        return dataclasses.asdict(self._params)
-
-    @property
-    def _layout_type(self) -> str:
-        return self._layout_cls.__name__
-
-    @property
-    def layout_type(self) -> str:
-        """Backward compatibility alias for _layout_type."""
-        return self._layout_type
-
-    def _copy_with(self, qdata=None, params=None, clone_params=True):
-        if params is None:
-            params = self._params.clone() if clone_params else self._params
-        return type(self)(
-            qdata if qdata is not None else self._qdata,
-            self._layout_cls,
-            params,
-        )
-
-
 # ==============================================================================
 # FP8 Layouts with Comfy-Specific Extensions
 # ==============================================================================
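
The import block also changes failure behavior: rather than raising, the module now logs an error and installs inert stand-ins so the rest of ComfyUI imports cleanly, with fp8/fp4 support simply unavailable. The shape of that pattern, sketched with a made-up dependency name:

    import logging

    try:
        import some_optional_backend as backend  # hypothetical dependency
        _BACKEND_AVAILABLE = True
    except ImportError as e:
        logging.error(f"Failed to import some_optional_backend: {e}")
        _BACKEND_AVAILABLE = False

        def get_layout_class(name):
            # Stub mirrors the real API's shape so callers can import it;
            # code paths that need it should check _BACKEND_AVAILABLE first.
            return None
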
@@ -81,7 +55,10 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         orig_shape = tuple(tensor.shape)
 
         if isinstance(scale, str) and scale == "recalculate":
-            scale = torch.amax(tensor.abs()) / torch.finfo(cls.FP8_DTYPE).max
+            scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
+            if tensor.dtype not in [torch.float32, torch.bfloat16]:  # Prevent scale from being too small
+                tensor_info = torch.finfo(tensor.dtype)
+                scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
 
         if scale is None:
             scale = torch.ones((), device=tensor.device, dtype=torch.float32)
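
The recalculated scale is now computed in float32, and for input dtypes other than float32/bfloat16 its inverse is clamped into the dtype's finite range, since quantization applies 1/scale and a tiny amax would otherwise overflow it. A worked example (requires a PyTorch build with float8 dtypes; numbers are illustrative):

    import torch

    fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0
    tensor = torch.full((4,), 1e-6, dtype=torch.float16)

    scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / fp8_max
    info = torch.finfo(torch.float16)
    scale = 1.0 / torch.clamp(1.0 / scale, min=info.min, max=info.max)
    # Without the clamp, 1/scale would be ~4.5e8, far beyond float16's max
    # of 65504; the clamp floors the scale at 1/65504 ≈ 1.53e-5 instead.
    print(scale.item())

The hunk below stores scale.float() in Params for the same reason, keeping the saved scale in float32 regardless of the input dtype.
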
@@ -97,7 +74,7 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
         else:
             qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
 
-        params = cls.Params(scale=scale, orig_dtype=orig_dtype, orig_shape=orig_shape)
+        params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
         return qdata, params
 
 
@@ -117,12 +94,10 @@ TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
 # Registry
 # ==============================================================================
 
-LAYOUTS = {
-    "TensorCoreFP8Layout": TensorCoreFP8Layout,  # Backward compat alias (E4M3)
-    "TensorCoreFP8E4M3Layout": TensorCoreFP8E4M3Layout,
-    "TensorCoreFP8E5M2Layout": TensorCoreFP8E5M2Layout,
-    "TensorCoreNVFP4Layout": TensorCoreNVFP4Layout,  # Direct from comfy_kitchen
-}
+register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
+register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
+register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
+register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
 
 QUANT_ALGOS = {
     "float8_e4m3fn": {
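
These register_layout_class calls replace the LAYOUTS dict outright, so every lookup goes through get_layout_class and additional layouts can be registered without editing a central table. Assuming comfy_kitchen is installed and that layouts derive from the imported QuantizedLayout base (an assumption, the required interface is not shown in this diff), an extension would look roughly like:

    from comfy_kitchen.tensor import (
        QuantizedLayout,
        register_layout_class,
        get_layout_class,
    )

    class MyExperimentalLayout(QuantizedLayout):  # hypothetical layout
        pass  # a real layout would implement quantize/dequantize hooks

    register_layout_class("MyExperimentalLayout", MyExperimentalLayout)
    assert get_layout_class("MyExperimentalLayout") is MyExperimentalLayout
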