An actually functional POC

This commit is contained in:
lspindler 2025-10-24 14:44:54 +02:00
parent 19ce6b056d
commit b6e0a53c11
4 changed files with 481 additions and 438 deletions

View File

@@ -7,121 +7,24 @@ import logging
import torch
# ==============================================================================
# Quantization Detection Functions
# ==============================================================================
def normalize_layer_name(full_key, known_prefixes):
"""
Strip model prefix and parameter suffix from a state dict key.
Args:
full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight")
known_prefixes: List of known model prefixes to strip
Returns:
Normalized layer name (e.g., "layer1")
"""
name = full_key
# Strip model prefix
for prefix in known_prefixes:
if name.startswith(prefix):
name = name[len(prefix):]
break
# Remove parameter suffix
for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]:
if name.endswith(suffix):
name = name[:-len(suffix)]
break
return name
def detect_layer_quantization(state_dict, prefix="model.diffusion_model."):
"""
Detect per-layer quantization configuration from state dict.
Detection priority:
1. Check for _quantization_metadata key (new format)
2. Check for scaled_fp8 key (legacy format - return None)
3. Check for per-layer scale_weight patterns (mixed detection)
4. No quantization detected (return None)
Args:
state_dict: Model state dictionary
prefix: Key prefix for model layers
Returns:
Dict mapping layer names to quantization configs, or None for legacy/no quantization.
Example return value:
{
"input_blocks.5.1.transformer_blocks.0.attn1.to_q": {
"format": "fp8_e4m3fn_scaled",
"params": {"use_fp8_matmul": True}
},
"middle_block.1.transformer_blocks.0.attn2.to_k": {
"format": "fp8_e5m2_scaled",
"params": {"use_fp8_matmul": True}
}
}
"""
# 1. Check for new metadata format
metadata_key = f"{prefix}_quantization_metadata"
if metadata_key in state_dict:
try:
metadata = state_dict.pop(metadata_key)
if isinstance(metadata, dict) and "layers" in metadata:
logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})")
return metadata["layers"]
else:
logging.warning(f"Invalid quantization metadata format, ignoring")
except Exception as e:
logging.error(f"Failed to parse quantization metadata: {e}")
return None
def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
# 1. Check for per-layer config in metadata
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError(f"Invalid quantization metadata format")
# 2. Check for legacy scaled_fp8 marker
# If present, return None to use legacy code path
scaled_fp8_key = f"{prefix}scaled_fp8"
if scaled_fp8_key in state_dict:
logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
return None
# 3. Check for per-layer scale patterns (mixed precision without metadata)
# Look for layers that have scale_weight but not all layers have it
known_prefixes = [prefix]
layer_configs = {}
layers_with_scale = set()
layers_with_weight = set()
for key in state_dict.keys():
if key.startswith(prefix):
if key.endswith(".scale_weight"):
layer_name = normalize_layer_name(key, known_prefixes)
layers_with_scale.add(layer_name)
# Detect format based on weight dtype
weight_key = f"{prefix}{layer_name}.weight"
if weight_key in state_dict:
weight_dtype = state_dict[weight_key].dtype
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2"
layer_configs[layer_name] = {
"format": format_name,
"params": {}
}
elif key.endswith(".weight") and not key.endswith(".scale_weight"):
layer_name = normalize_layer_name(key, known_prefixes)
layers_with_weight.add(layer_name)
# If we found scale_weight on some but not all layers, it's mixed precision
if layer_configs and len(layers_with_scale) < len(layers_with_weight):
logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers")
return layer_configs
# 4. No quantization detected
return None
@@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")

View File

@@ -345,19 +345,8 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
"""
Legacy FP8 linear function - now uses tensor subclass infrastructure.
This function maintains backward compatibility with existing code while
routing all FP8 computation through the unified tensor subclass system.
All actual FP8 matmul logic is handled by the registered operation handlers
in quant_ops.py via __torch_dispatch__.
Args:
self: Linear layer with FP8 weight and scale parameters
input: Input tensor (any dtype)
Returns:
Output tensor or None if weight is not FP8
Legacy FP8 linear function for backward compatibility.
Uses QuantizedTensor subclass for dispatch.
"""
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
@@ -372,10 +361,8 @@ def fp8_linear(self, input):
input_dtype = input.dtype
if len(input.shape) == 3:
# Get weight and bias using standard casting
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
# Get scales (same as before)
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
@@ -388,14 +375,13 @@ def fp8_linear(self, input):
else:
scale_input = scale_input.to(input.device)
# Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
# Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
# This is the key unification: all FP8 computation goes through one path
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
# Reshape output
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
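The point of the rewrite above is that fp8_linear no longer contains any matmul logic: it only wraps the operands and calls the ordinary functional API, and __torch_dispatch__ does the rest. A standalone sketch of the same pattern, assuming an FP8-capable CUDA device and the QuantizedTensor / TensorCoreFP8Layout classes from quant_ops.py; the scales are arbitrary:
import torch
device = "cuda"  # FP8 matmul needs an FP8-capable GPU
w_fp8 = torch.randn(128, 64, device=device).to(torch.float8_e4m3fn)   # pre-quantized weight payload
qw = QuantizedTensor(w_fp8, TensorCoreFP8Layout,
                     {"scale": torch.tensor(0.02, device=device), "orig_dtype": torch.bfloat16})
x = torch.randn(8, 64, dtype=torch.bfloat16, device=device)
qx = QuantizedTensor.from_float(x, TensorCoreFP8Layout,
                                scale=torch.tensor(0.01, device=device),
                                fp8_dtype=torch.float8_e4m3fn)
# __torch_dispatch__ intercepts aten.linear and routes it to the registered
# TensorCoreFP8Layout handler, which calls torch._scaled_mm under the hood.
y = torch.nn.functional.linear(qx, qw)   # bf16 output, shape (8, 128)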
@@ -492,183 +478,117 @@ if CUBLAS_IS_AVAILABLE:
return super().forward(*args, **kwargs)
# Import quantization operations from separate module
from .quant_ops import QuantizedTensorFP8
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": {
"dtype": torch.float8_e4m3fn,
"layout_type": TensorCoreFP8Layout,
"parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
}
}
}
class MixedPrecisionOps(disable_weight_init):
"""
Operations class supporting per-layer quantization (mixed precision).
This class enables different layers to use different quantization formats
within the same model (e.g., some layers FP8, others BF16).
Layer-specific quantization is configured via _layer_quant_config class variable,
which is set by pick_operations() when a model has mixed precision.
"""
_layer_quant_config = {} # Class variable set by pick_operations()
class Linear(disable_weight_init.Linear):
"""Linear layer with optional per-layer quantization using tensor subclasses"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.quant_format = None
self.quant_scale = None
self._quantization_initialized = False
_layer_quant_config = {}
_compute_dtype = torch.bfloat16
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
def reset_parameters(self):
# Don't allocate weights - return None like disable_weight_init
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
"""
Called by PyTorch during load_state_dict.
Load weight and wrap in QuantizedTensorFP8 if this layer is quantized.
"""
# Call parent to load weight and bias first
super()._load_from_state_dict(
state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs
)
# After weight is loaded, wrap it if this layer is quantized
if not self._quantization_initialized:
# Normalize layer name from prefix
layer_name = prefix.rstrip('.')
# Strip known model prefixes
for model_prefix in ["model.diffusion_model.", "model.model.", "net."]:
if layer_name.startswith(model_prefix):
layer_name = layer_name[len(model_prefix):]
break
# Check if this layer has quantization config
if layer_name in MixedPrecisionOps._layer_quant_config:
config = MixedPrecisionOps._layer_quant_config[layer_name]
self.quant_format = config.get("format", "fp8_e4m3fn")
# Load scale parameter
scale_key = f"{prefix}scale_weight"
if scale_key in state_dict:
self.quant_scale = state_dict[scale_key]
# Wrap weight in QuantizedTensorFP8
if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
try:
# Determine original dtype (default to bfloat16)
orig_dtype = torch.bfloat16
# Wrap weight in quantized tensor subclass
quantized_weight = QuantizedTensorFP8(
self.weight.data,
self.quant_scale,
orig_dtype=orig_dtype
)
# Replace weight parameter with wrapped version
self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})")
except Exception as e:
logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}")
self.quant_format = None
self.quant_scale = None
else:
logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization")
self.quant_format = None
self.quant_scale = None
else:
logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict")
self.quant_format = None
self._quantization_initialized = True
def _save_to_state_dict(self, destination, prefix, keep_vars):
"""Save layer parameters including quantization scale"""
# First unwrap the weight if it's quantized
if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8):
# Temporarily unwrap to save the raw FP8 data
quantized_tensor = self.weight.data
raw_fp8_data = quantized_tensor._raw_data
original_weight = self.weight
self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False)
# Call parent to save unwrapped weight
super()._save_to_state_dict(destination, prefix, keep_vars)
# Restore the wrapped weight
self.weight = original_weight
# Save the scale parameter
if self.quant_scale is not None:
destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach()
else:
# Standard path for non-quantized weights
super()._save_to_state_dict(destination, prefix, keep_vars)
def forward_comfy_cast_weights(self, input):
"""
Forward pass - tensor subclass handles dispatch automatically!
__torch_dispatch__ will route to registered handlers based on tensor types.
"""
weight, bias = cast_bias_weight(self, input)
# Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it!
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
"""Main forward pass"""
run_every_op()
# Same logic as disable_weight_init.Linear
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
"""Create Conv layer (same as disable_weight_init)"""
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
mixin = QUANT_FORMAT_MIXINS[quant_format]
self.layout_type = mixin["layout_type"]
layout_params = {
'scale': state_dict.pop(f"{prefix}weight_scale", None),
'orig_dtype': MixedPrecisionOps._compute_dtype
}
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
requires_grad=False
)
for param_name, param_value in mixin["parameters"].items():
_v = state_dict.pop(f"{prefix}{param_name}", None)
if _v is None:
continue
setattr(self, param_name, _v.to(device=device))
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._forward(input, weight, bias)
def forward(self, input, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
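For reference, a sketch of the checkpoint keys a single quantized Linear is expected to provide under this scheme; the layer path and values are hypothetical, and the parameter names follow QUANT_FORMAT_MIXINS above:
import torch
sd_slice = {
    "input_blocks.0.0.weight":       torch.zeros(64, 64, dtype=torch.float8_e4m3fn),  # raw FP8 payload, wrapped into a QuantizedTensor on load
    "input_blocks.0.0.weight_scale": torch.tensor(0.02),   # stored in the layout params as 'scale'
    "input_blocks.0.0.input_scale":  torch.tensor(0.01),   # optional; enables activation quantization in forward()
    "input_blocks.0.0.bias":         torch.zeros(64, dtype=torch.bfloat16),
}
# Only layers listed in MixedPrecisionOps._layer_quant_config get this treatment;
# everything else is loaded as a plain parameter in the compute dtype.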
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
"""
Select appropriate operations class for model.
NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3).
LEGACY: All other paths unchanged for backward compatibility.
Args:
weight_dtype: Weight storage dtype
compute_dtype: Computation dtype
load_device: Device for loading
disable_fast_fp8: Disable fast FP8 paths
fp8_optimizations: Enable FP8 optimizations
scaled_fp8: Legacy FP8 dtype marker
model_config: Model config object (optional, for mixed precision support)
Returns:
Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init)
"""
# NEW: Check for mixed precision
# If model_config.layer_quant_config exists, use new MixedPrecisionOps.
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
# MixedPrecisionOps._compute_dtype = compute_dtype # TODO
MixedPrecisionOps._compute_dtype = torch.bfloat16
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return MixedPrecisionOps
# LEGACY paths (unchanged)
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
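A sketch of how the new branch is picked; DummyConfig is hypothetical and only the layer_quant_config attribute matters:
import torch
class DummyConfig:
    # Any non-empty layer-name -> config mapping triggers the mixed precision path.
    layer_quant_config = {"input_blocks.0.0": {"format": "float8_e4m3fn", "params": {}}}
ops = pick_operations(torch.float8_e4m3fn, torch.bfloat16, model_config=DummyConfig())
assert ops is MixedPrecisionOps
# The config is stashed on the class so Linear._load_from_state_dict can consult it.
assert MixedPrecisionOps._layer_quant_config is DummyConfig.layer_quant_config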

View File

@@ -1,42 +1,79 @@
import torch
import logging
from typing import Tuple, Dict
# ==============================================================================
# Global Operation Registry
# ==============================================================================
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
# Global operation registry: torch operation → handler function
_QUANT_OP_REGISTRY = {}
def register_quant_op(torch_op):
def register_layout_op(torch_op, layout_type):
"""
Decorator to register an operation handler.
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_quant_op(torch.ops.aten.linear.default)
def handle_linear_fp8(func, args, kwargs):
# Implementation
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
_QUANT_OP_REGISTRY[torch_op] = handler_func
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def get_quant_handler(torch_op):
"""Get registered handler for an operation"""
return _QUANT_OP_REGISTRY.get(torch_op)
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def list_registered_ops():
"""List all registered quantized operations"""
return list(_QUANT_OP_REGISTRY.keys())
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
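Adding a new format then means subclassing QuantizedLayout and registering handlers against it; register_generic_util works the same way for layout-agnostic ops. A sketch with a hypothetical Int8Layout that is not part of this commit:
class Int8Layout(QuantizedLayout):
    """Hypothetical int8 layout, shown only to illustrate the registry."""
    ...
@register_layout_op(torch.ops.aten.linear.default, Int8Layout)
def int8_linear(func, args, kwargs):
    # Placeholder handler: dequantize and defer to the standard linear.
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight, bias)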
# ==============================================================================
# comfy_kitchen Integration
# ==============================================================================
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
try:
import comfy_kitchen as ck
@@ -53,106 +90,144 @@ except Exception as e:
logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks")
# ==============================================================================
# Quantized Tensor Subclass
# ==============================================================================
class QuantizedTensorFP8(torch.Tensor):
class QuantizedLayout:
"""
Tensor subclass for FP8 quantized data.
Automatically handles operations via __torch_dispatch__.
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError(f"TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16):
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized FP8 tensor.
Create a quantized tensor.
Args:
tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2)
scale: Scale factor for dequantization (scalar tensor)
orig_dtype: Original dtype before quantization
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_subclass(cls, tensor, require_grad=False)
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
def __init__(self, tensor, scale, orig_dtype=torch.bfloat16):
self._scale = scale
self._orig_dtype = orig_dtype
# Store a reference to prevent infinite recursion in dequantize
self._raw_data = tensor.contiguous()
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata.contiguous()
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
return (f"QuantizedTensorFP8(shape={self.shape}, "
f"scale={self._scale:.4f}, dtype={self._orig_dtype})")
layout_name = self._layout_type.__name__
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_q_data"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
orig_dtype = tensor.dtype
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
tensor_fp8 = None
if _CK_AVAILABLE:
try:
tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
except Exception as e:
logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
if tensor_fp8 is None:
lp_amax = torch.finfo(fp8_dtype).max
tensor_scaled = tensor.float() / scale
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
@classmethod
def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
if strategy == "amax":
scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max
scale = scale.to(tensor.device, dtype=torch.float32)
else:
raise ValueError(f"Unknown quantization strategy: {strategy}. "
f"Supported: 'amax'")
return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
def dequantize(self) -> torch.Tensor:
return self._layout_type.dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Special case: skip dispatch for internal tensor operations
# that are used for unwrapping (to avoid recursion)
if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]:
# For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach
if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8):
# Special handling for detach - return a new QuantizedTensorFP8
qt = args[0]
detached_data = qt._raw_data.detach()
return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype)
# For other ops, just unwrap
def unwrap(arg):
if isinstance(arg, QuantizedTensorFP8):
return arg._raw_data
return arg
new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args)
return func(*new_args, **kwargs)
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Look up registered handler for this operation
handler = _QUANT_OP_REGISTRY.get(func)
if handler:
return handler(func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# No handler - dequantize and use standard path
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
"""Fallback: dequantize all quantized tensors"""
def dequant_arg(arg):
if isinstance(arg, QuantizedTensorFP8):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
@@ -161,75 +236,220 @@ class QuantizedTensorFP8(torch.Tensor):
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def dequantize(self) -> torch.Tensor:
plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
return plain_tensor * self._scale
def detach(self):
"""Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
detached_data = self._raw_data.detach()
return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)
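A quick round trip through the new subclass, assuming no comfy_kitchen kernels (the PyTorch fallback path) and dynamic per-tensor scaling:
import torch
w = torch.randn(128, 64, dtype=torch.bfloat16)
qw = QuantizedTensor.from_float(w, TensorCoreFP8Layout, fp8_dtype=torch.float8_e4m3fn)
print(qw)                    # QuantizedTensor(shape=torch.Size([128, 64]), layout=TensorCoreFP8Layout, ...)
w_approx = qw.dequantize()   # reconstructs the values up to FP8 rounding error
max_err = (w_approx.float() - w.float()).abs().max()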
# ==============================================================================
# Operation Handlers for Quantized Tensors
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
@register_quant_op(torch.ops.aten.linear.default)
def handle_linear_fp8(func, args, kwargs):
"""
Handle F.linear() with quantized inputs.
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
Supports:
- QuantizedTensorFP8 input + QuantizedTensorFP8 weight
- QuantizedTensorFP8 input + regular weight
- Regular input + QuantizedTensorFP8 weight
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata)
qt_dest._layout_type = src._layout_type
qt_dest._layout_params = _copy_layout_params(src._layout_params)
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
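Device movement is handled entirely by these generic utilities: the payload and every tensor-valued layout param move together, and dtype changes are refused. A sketch (which aten op .to() decomposes to can vary by PyTorch version, hence both handlers above):
qt = QuantizedTensor.from_float(torch.randn(8, 8), TensorCoreFP8Layout)
if torch.cuda.is_available():
    qt_cuda = qt.to("cuda")                 # intercepted by the handlers above
    assert isinstance(qt_cuda, QuantizedTensor)
    assert qt_cuda._qdata.device.type == "cuda"
    assert qt_cuda._layout_params["scale"].device.type == "cuda"   # scale travels with the payload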
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn):
orig_dtype = tensor.dtype
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if _CK_AVAILABLE and tensor.device.type == "cuda":
qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
else:
lp_amax = torch.finfo(fp8_dtype).max
tensor_scaled = tensor.float() / scale
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return qdata, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
return plain_tensor * scale
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
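The layout can also be used directly, without the tensor subclass; a sketch of the explicit-scale path on CPU (arbitrary values):
t = torch.randn(32, 32)
qdata, params = TensorCoreFP8Layout.quantize(t, scale=0.05, fp8_dtype=torch.float8_e4m3fn)
assert qdata.dtype == torch.float8_e4m3fn
assert set(params) == {"scale", "orig_dtype"}
t_approx = TensorCoreFP8Layout.dequantize(qdata, **params)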
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype)
# Case 1: Both input and weight are FP8
if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
# Get plain tensors to avoid dispatch recursion
plain_input = input_tensor._raw_data
plain_weight = weight._raw_data
weight_t = plain_weight.t() # Keep as column-major for cuBLASLt
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input,
plain_input.reshape(-1, input_shape[2]),
weight_t,
bias=bias,
scale_a=input_tensor._scale,
scale_b=weight._scale,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple):
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = input_tensor._scale * weight._scale
return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? Can't cuBLAS return it
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
else:
return output
except Exception as e:
logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: Only weight is quantized
if isinstance(weight, QuantizedTensorFP8):
weight_dq = weight.dequantize()
input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor
return torch.nn.functional.linear(input_dq, weight_dq, bias)
# Case 3: Only input is quantized
elif isinstance(input_tensor, QuantizedTensorFP8):
input_dq = input_tensor.dequantize()
return torch.nn.functional.linear(input_dq, weight, bias)
# Case 4: Neither is quantized (shouldn't happen, but handle it)
else:
return torch.nn.functional.linear(input_tensor, weight, bias)
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
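When only the weight is quantized (the common case without activation scales), the handler takes the dequantize-and-defer path; a sketch on CPU:
x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randn(128, 64, dtype=torch.bfloat16)
qw = QuantizedTensor.from_float(w, TensorCoreFP8Layout, fp8_dtype=torch.float8_e4m3fn)
# Case 2 above: weight-only quantization falls back to dequantize + F.linear.
y = torch.nn.functional.linear(x, qw)
assert y.shape == (4, 128)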

View File

@@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
"""
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None:
new_sd = sd
@@ -1346,8 +1346,8 @@ def load_diffusion_model(unet_path, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))