An actually functional POC

lspindler 2025-10-24 14:44:54 +02:00
parent 19ce6b056d
commit b6e0a53c11
4 changed files with 481 additions and 438 deletions

View File

@ -7,121 +7,24 @@ import logging
import torch


def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
    # 1. Check for per-layer config in metadata
    quant_key = "_quantization_metadata"
    if metadata is not None and quant_key in metadata:
        quant_metadata = metadata.pop(quant_key)
        quant_metadata = json.loads(quant_metadata)
        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
            return quant_metadata["layers"]
        else:
            raise ValueError("Invalid quantization metadata format")

    # 2. Check for legacy scaled_fp8 marker
    scaled_fp8_key = f"{prefix}scaled_fp8"
    if scaled_fp8_key in state_dict:
        logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
        return None

    return None
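
For reference, a minimal sketch of the metadata this detection path consumes. Only the "_quantization_metadata" key, the JSON encoding, and the "layers" / "format_version" fields come from the code above; the layer names, the version string, and the import path are illustrative assumptions.

# Illustrative sketch only, not part of the commit.
import json
from comfy.model_detection import detect_layer_quantization  # assumed module path

quant_metadata = {
    "format_version": "1.0",  # hypothetical version string
    "layers": {
        "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {"format": "float8_e4m3fn"},
        "middle_block.1.transformer_blocks.0.attn2.to_k": {"format": "float8_e4m3fn"},
    },
}
metadata = {"_quantization_metadata": json.dumps(quant_metadata)}

layer_configs = detect_layer_quantization(state_dict={}, metadata=metadata)
# -> {"input_blocks.5.1.transformer_blocks.0.attn1.to_q": {"format": "float8_e4m3fn"}, ...}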
@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
        model_config.optimizations["fp8"] = True

    # Detect per-layer quantization (mixed precision)
    layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix)
    if layer_quant_config:
        model_config.layer_quant_config = layer_quant_config
        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")

View File

@ -345,19 +345,8 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
    """
    Legacy FP8 linear function for backward compatibility.
    Uses QuantizedTensor subclass for dispatch.
    """
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
@ -372,10 +361,8 @@ def fp8_linear(self, input):
    input_dtype = input.dtype
    if len(input.shape) == 3:
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
@ -388,14 +375,13 @@ def fp8_linear(self, input):
        else:
            scale_input = scale_input.to(input.device)

        # Wrap weight in QuantizedTensor - this enables unified dispatch
        # Call F.linear - __torch_dispatch__ routes to the fp8_linear handler in quant_ops.py
        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

        if tensor_2d:
            return o.reshape(input_shape[0], -1)
        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
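
As the comments in this hunk describe, the legacy path now only wraps the weight and input and lets __torch_dispatch__ route F.linear to the handler registered in quant_ops.py. A rough stand-alone sketch of that wrapping follows; the comfy.quant_ops import path is an assumption based on the relative import used in this file, and the FP8 matmul branch itself still requires a GPU that supports torch._scaled_mm.

# Illustrative sketch only, not part of the commit.
import torch
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout  # assumed import path

x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.randn(32, 64, dtype=torch.bfloat16)
qx = QuantizedTensor.from_float(x, TensorCoreFP8Layout, fp8_dtype=torch.float8_e4m3fn)
qw = QuantizedTensor.from_float(w, TensorCoreFP8Layout, fp8_dtype=torch.float8_e4m3fn)

# Per the commit's comments, this call is intercepted by __torch_dispatch__ and
# handled by the fp8_linear handler registered for TensorCoreFP8Layout.
y = torch.nn.functional.linear(qx, qw)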
@ -492,183 +478,117 @@ if CUBLAS_IS_AVAILABLE:
            return super().forward(*args, **kwargs)


# ==============================================================================
# Mixed Precision Operations
# ==============================================================================

from .quant_ops import QuantizedTensor, TensorCoreFP8Layout

QUANT_FORMAT_MIXINS = {
    "float8_e4m3fn": {
        "dtype": torch.float8_e4m3fn,
        "layout_type": TensorCoreFP8Layout,
        "parameters": {
            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
        }
    }
}


class MixedPrecisionOps(disable_weight_init):
    _layer_quant_config = {}
    _compute_dtype = torch.bfloat16

    class Linear(torch.nn.Module, CastWeightBiasOp):
        def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None,
        ) -> None:
            super().__init__()

            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
            # self.factory_kwargs = {"device": device, "dtype": dtype}

            self.in_features = in_features
            self.out_features = out_features
            if bias:
                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
            else:
                self.register_parameter("bias", None)

            self.tensor_class = None

        def reset_parameters(self):
            return None

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys, error_msgs):
            device = self.factory_kwargs["device"]
            layer_name = prefix.rstrip('.')
            weight_key = f"{prefix}weight"
            weight = state_dict.pop(weight_key, None)
            if weight is None:
                raise ValueError(f"Missing weight for layer {layer_name}")

            if layer_name not in MixedPrecisionOps._layer_quant_config:
                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
            else:
                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
                if quant_format is None:
                    raise ValueError(f"Unknown quantization format for layer {layer_name}")

                mixin = QUANT_FORMAT_MIXINS[quant_format]
                self.layout_type = mixin["layout_type"]

                layout_params = {
                    'scale': state_dict.pop(f"{prefix}weight_scale", None),
                    'orig_dtype': MixedPrecisionOps._compute_dtype
                }
                self.weight = torch.nn.Parameter(
                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
                    requires_grad=False
                )
                for param_name, param_value in mixin["parameters"].items():
                    _v = state_dict.pop(f"{prefix}{param_name}", None)
                    if _v is None:
                        continue
                    setattr(self, param_name, _v.to(device=device))

            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        def _forward(self, input, weight, bias):
            return torch.nn.functional.linear(input, weight, bias)

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._forward(input, weight, bias)

        def forward(self, input, *args, **kwargs):
            run_every_op()

            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(input, *args, **kwargs)
            if (getattr(self, 'layout_type', None) is not None and
                    getattr(self, 'input_scale', None) is not None and
                    not isinstance(input, QuantizedTensor)):
                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
            return self._forward(input, self.weight, self.bias)


def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
    # If model_config.layer_quant_config exists, use the new MixedPrecisionOps.
    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
        # MixedPrecisionOps._compute_dtype = compute_dtype  # TODO
        MixedPrecisionOps._compute_dtype = torch.bfloat16
        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
        return MixedPrecisionOps

    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
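
A small sketch of how the new selection path is meant to be driven. The function and class names are the ones defined in this file; the SimpleNamespace stand-in for a detected model config and the layer name are illustrative assumptions.

# Illustrative sketch only, not part of the commit.
import types
import torch

fake_config = types.SimpleNamespace(layer_quant_config={
    "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {"format": "float8_e4m3fn"},
})

ops = pick_operations(torch.float8_e4m3fn, torch.bfloat16, model_config=fake_config)
assert ops is MixedPrecisionOps

# Layers built from these ops defer weight allocation; the (possibly quantized)
# weight and any weight_scale / input_scale entries are materialized later in
# Linear._load_from_state_dict when the checkpoint is loaded.
layer = ops.Linear(64, 32, bias=False)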

View File

@ -1,42 +1,79 @@
import torch
import logging
from typing import Tuple, Dict

_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}


def register_layout_op(torch_op, layout_type):
    """
    Decorator to register a layout-specific operation handler.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
        layout_type: Layout class (e.g., TensorCoreFP8Layout)

    Example:
        @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
        def fp8_linear(func, args, kwargs):
            # FP8-specific linear implementation
            ...
    """
    def decorator(handler_func):
        if torch_op not in _LAYOUT_REGISTRY:
            _LAYOUT_REGISTRY[torch_op] = {}
        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
        return handler_func
    return decorator


def register_generic_util(torch_op):
    """
    Decorator to register a generic utility that works for all layouts.

    Args:
        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)

    Example:
        @register_generic_util(torch.ops.aten.detach.default)
        def generic_detach(func, args, kwargs):
            # Works for any layout
            ...
    """
    def decorator(handler_func):
        _GENERIC_UTILS[torch_op] = handler_func
        return handler_func
    return decorator


def _get_layout_from_args(args):
    for arg in args:
        if isinstance(arg, QuantizedTensor):
            return arg._layout_type
        elif isinstance(arg, (list, tuple)):
            for item in arg:
                if isinstance(item, QuantizedTensor):
                    return item._layout_type
    return None


def _move_layout_params_to_device(params, device):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.to(device=device)
        else:
            new_params[k] = v
    return new_params


def _copy_layout_params(params):
    new_params = {}
    for k, v in params.items():
        if isinstance(v, torch.Tensor):
            new_params[k] = v.clone()
        else:
            new_params[k] = v
    return new_params
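
To illustrate how the two registries above are meant to be extended, a hypothetical registration for a layout that does not exist in this commit (MyInt4Layout and the handler body are placeholders):

# Hypothetical extension, for illustration only.
class MyInt4Layout:  # a real layout would subclass QuantizedLayout (defined below)
    pass

@register_layout_op(torch.ops.aten.linear.default, MyInt4Layout)
def int4_linear(func, args, kwargs):
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    # A real handler would consume the packed int4 payload directly; this
    # placeholder just dequantizes the weight and falls back to a plain linear.
    return torch.nn.functional.linear(input_tensor, weight.dequantize(), bias)

# The lookup performed later by QuantizedTensor.__torch_dispatch__:
assert _LAYOUT_REGISTRY[torch.ops.aten.linear.default][MyInt4Layout] is int4_linear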
try:
    import comfy_kitchen as ck
@ -53,106 +90,144 @@ except Exception as e:
    logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks")


class QuantizedLayout:
    """
    Base class for quantization layouts.

    A layout encapsulates the format-specific logic for quantization/dequantization
    and provides a uniform interface for extracting raw tensors needed for computation.

    New quantization formats should subclass this and implement the required methods.
    """
    @classmethod
    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
        raise NotImplementedError(f"{cls.__name__} must implement quantize()")

    @staticmethod
    def dequantize(qdata, **layout_params) -> torch.Tensor:
        raise NotImplementedError("TensorLayout must implement dequantize()")

    @classmethod
    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")


class QuantizedTensor(torch.Tensor):
    """
    Universal quantized tensor that works with any layout.

    This tensor subclass uses a pluggable layout system to support multiple
    quantization formats (FP8, INT4, INT8, etc.) without code duplication.
    The layout_type determines format-specific behavior, while common operations
    (detach, clone, to) are handled generically.

    Attributes:
        _qdata: The quantized tensor data
        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
    """
    @staticmethod
    def __new__(cls, qdata, layout_type, layout_params):
        """
        Create a quantized tensor.

        Args:
            qdata: The quantized data tensor
            layout_type: Layout class (subclass of QuantizedLayout)
            layout_params: Dict with layout-specific parameters
        """
        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)

    def __init__(self, qdata, layout_type, layout_params):
        self._qdata = qdata.contiguous()
        self._layout_type = layout_type
        self._layout_params = layout_params

    def __repr__(self):
        layout_name = self._layout_type.__name__
        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"

    @property
    def layout_type(self):
        return self._layout_type

    def __tensor_flatten__(self):
        """
        Tensor flattening protocol for proper device movement.
        """
        inner_tensors = ["_qdata"]
        ctx = {
            "layout_type": self._layout_type,
        }
        tensor_params = {}
        non_tensor_params = {}
        for k, v in self._layout_params.items():
            if isinstance(v, torch.Tensor):
                tensor_params[k] = v
            else:
                non_tensor_params[k] = v
        ctx["tensor_param_keys"] = list(tensor_params.keys())
        ctx["non_tensor_params"] = non_tensor_params
        for k, v in tensor_params.items():
            attr_name = f"_layout_param_{k}"
            object.__setattr__(self, attr_name, v)
            inner_tensors.append(attr_name)
        return inner_tensors, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        """
        Tensor unflattening protocol for proper device movement.
        Reconstructs the QuantizedTensor after device movement.
        """
        layout_type = ctx["layout_type"]
        layout_params = dict(ctx["non_tensor_params"])
        for key in ctx["tensor_param_keys"]:
            attr_name = f"_layout_param_{key}"
            layout_params[key] = inner_tensors[attr_name]
        return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)

    @classmethod
    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
        return cls(qdata, layout_type, layout_params)

    def dequantize(self) -> torch.Tensor:
        return self._layout_type.dequantize(self._qdata, **self._layout_params)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Step 1: Check generic utilities first (detach, clone, to, etc.)
        if func in _GENERIC_UTILS:
            return _GENERIC_UTILS[func](func, args, kwargs)

        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
        layout_type = _get_layout_from_args(args)
        if layout_type and func in _LAYOUT_REGISTRY:
            handler = _LAYOUT_REGISTRY[func].get(layout_type)
            if handler:
                return handler(func, args, kwargs)

        # Step 3: Fallback to dequantization
        if isinstance(args[0] if args else None, QuantizedTensor):
            logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
        return cls._dequant_and_fallback(func, args, kwargs)

    @classmethod
    def _dequant_and_fallback(cls, func, args, kwargs):
        def dequant_arg(arg):
            if isinstance(arg, QuantizedTensor):
                return arg.dequantize()
            elif isinstance(arg, (list, tuple)):
                return type(arg)(dequant_arg(a) for a in arg)
@ -162,74 +237,219 @@ class QuantizedTensorFP8(torch.Tensor):
        new_kwargs = dequant_arg(kwargs)
        return func(*new_args, **new_kwargs)


# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================

def _create_transformed_qtensor(qt, transform_fn):
    new_data = transform_fn(qt._qdata)
    new_params = _copy_layout_params(qt._layout_params)
    return QuantizedTensor(new_data, qt._layout_type, new_params)


def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
    if target_dtype is not None and target_dtype != qt.dtype:
        logging.warning(
            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
            f"but not supported for quantized tensors. Ignoring dtype."
        )
    if target_layout is not None and target_layout != torch.strided:
        logging.warning(
            f"QuantizedTensor: layout change requested to {target_layout}, "
            f"but not supported. Ignoring layout."
        )

    # Handle device transfer
    current_device = qt._qdata.device
    if target_device is not None:
        # Normalize device for comparison
        if isinstance(target_device, str):
            target_device = torch.device(target_device)
        if isinstance(current_device, str):
            current_device = torch.device(current_device)

        if target_device != current_device:
            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
            new_q_data = qt._qdata.to(device=target_device)
            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
            return new_qt

    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
    return qt


@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
    """Detach operation - creates a detached copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.detach())
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
    """Clone operation - creates a deep copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.clone())
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
    """Device/dtype transfer operation - handles .to(device) calls."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            op_name="_to_copy"
        )
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
    """Handle .to(device) calls using the dtype_layout variant."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            target_layout=kwargs.get('layout', None),
            op_name="to"
        )
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
    qt_dest = args[0]
    src = args[1]
    if isinstance(qt_dest, QuantizedTensor):
        if isinstance(src, QuantizedTensor):
            # Copy from another quantized tensor
            qt_dest._qdata.copy_(src._qdata)
            qt_dest._layout_type = src._layout_type
            qt_dest._layout_params = _copy_layout_params(src._layout_params)
        else:
            # Copy from regular tensor - just copy raw data
            qt_dest._qdata.copy_(src)
        return qt_dest
    return func(*args, **kwargs)


@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
    return True


# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================

class TensorCoreFP8Layout(QuantizedLayout):
    """
    Storage format:
    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
    - scale: Scalar tensor (float32) for dequantization
    - orig_dtype: Original dtype before quantization (for casting back)
    """
    @classmethod
    def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn):
        orig_dtype = tensor.dtype
        if scale is None:
            scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max
        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)

        if _CK_AVAILABLE and tensor.device.type == "cuda":
            qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
        else:
            lp_amax = torch.finfo(fp8_dtype).max
            tensor_scaled = tensor.float() / scale
            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
            qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)

        layout_params = {
            'scale': scale,
            'orig_dtype': orig_dtype
        }
        return qdata, layout_params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
        return plain_tensor * scale

    @classmethod
    def get_plain_tensors(cls, qtensor):
        return qtensor._qdata, qtensor._layout_params['scale']


@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

        out_dtype = kwargs.get("out_dtype")
        if out_dtype is None:
            out_dtype = input_tensor._layout_params['orig_dtype']

        weight_t = plain_weight.t()

        tensor_2d = False
        if len(plain_input.shape) == 2:
            tensor_2d = True
            plain_input = plain_input.unsqueeze(1)
        input_shape = plain_input.shape
        if len(input_shape) != 3:
            return None

        try:
            output = torch._scaled_mm(
                plain_input.reshape(-1, input_shape[2]),
                weight_t,
                bias=bias,
                scale_a=scale_a,
                scale_b=scale_b,
                out_dtype=out_dtype,
            )
            if not tensor_2d:
                output = output.reshape((-1, input_shape[1], weight.shape[0]))
            if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                output_scale = scale_a * scale_b
                output_params = {
                    'scale': output_scale,
                    'orig_dtype': input_tensor._layout_params['orig_dtype']
                }
                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
            else:
                return output
        except Exception as e:
            raise RuntimeError(f"FP8 _scaled_mm failed: {e}")

    # Fallback: dequantize whichever operands are quantized and run a standard linear
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()

    return torch.nn.functional.linear(input_tensor, weight, bias)
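
Finally, a short round trip through the FP8 layout above. Only the weight is quantized here, so per the handler's fallback branch the call ends in a plain dequantize-and-linear, which runs on CPU through the PyTorch path in TensorCoreFP8Layout.quantize (neither comfy_kitchen nor torch._scaled_mm is required); the shapes and values are arbitrary.

# Illustrative sketch only, not part of the commit.
import torch

w = torch.randn(32, 64, dtype=torch.bfloat16)
qw = QuantizedTensor.from_float(w, TensorCoreFP8Layout, fp8_dtype=torch.float8_e4m3fn)
print(qw)                    # QuantizedTensor(shape=torch.Size([32, 64]), layout=TensorCoreFP8Layout, ...)
w_back = qw.dequantize()     # back to bfloat16, equal to w up to FP8 rounding

x = torch.randn(4, 64, dtype=torch.bfloat16)
y = torch.nn.functional.linear(x, qw)    # weight-only case -> dequantization fallback
ref = torch.nn.functional.linear(x, w)
print((y - ref).abs().max())             # small, bounded by the FP8 quantization error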

View File

@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
    return (model_patcher, clip, vae, clipvision)


def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
    """
    Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
    weight_dtype = comfy.utils.weight_dtype(sd)
    load_device = model_management.get_torch_device()
    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)

    if model_config is not None:
        new_sd = sd
@ -1346,8 +1346,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}):
    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
    if model is None:
        logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
        raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))