mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 08:22:36 +08:00
An actually functional POC
This commit is contained in:
parent
19ce6b056d
commit
b6e0a53c11
@ -7,121 +7,24 @@ import logging
|
||||
import torch
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Quantization Detection Functions
|
||||
# ==============================================================================
|
||||
|
||||
def normalize_layer_name(full_key, known_prefixes):
|
||||
"""
|
||||
Strip model prefix and parameter suffix from a state dict key.
|
||||
|
||||
Args:
|
||||
full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight")
|
||||
known_prefixes: List of known model prefixes to strip
|
||||
|
||||
Returns:
|
||||
Normalized layer name (e.g., "layer1")
|
||||
"""
|
||||
name = full_key
|
||||
|
||||
# Strip model prefix
|
||||
for prefix in known_prefixes:
|
||||
if name.startswith(prefix):
|
||||
name = name[len(prefix):]
|
||||
break
|
||||
|
||||
# Remove parameter suffix
|
||||
for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]:
|
||||
if name.endswith(suffix):
|
||||
name = name[:-len(suffix)]
|
||||
break
|
||||
|
||||
return name
|
||||
|
||||
|
||||
def detect_layer_quantization(state_dict, prefix="model.diffusion_model."):
|
||||
"""
|
||||
Detect per-layer quantization configuration from state dict.
|
||||
|
||||
Detection priority:
|
||||
1. Check for _quantization_metadata key (new format)
|
||||
2. Check for scaled_fp8 key (legacy format - return None)
|
||||
3. Check for per-layer scale_weight patterns (mixed detection)
|
||||
4. No quantization detected (return None)
|
||||
|
||||
Args:
|
||||
state_dict: Model state dictionary
|
||||
prefix: Key prefix for model layers
|
||||
|
||||
Returns:
|
||||
Dict mapping layer names to quantization configs, or None for legacy/no quantization.
|
||||
|
||||
Example return value:
|
||||
{
|
||||
"input_blocks.5.1.transformer_blocks.0.attn1.to_q": {
|
||||
"format": "fp8_e4m3fn_scaled",
|
||||
"params": {"use_fp8_matmul": True}
|
||||
},
|
||||
"middle_block.1.transformer_blocks.0.attn2.to_k": {
|
||||
"format": "fp8_e5m2_scaled",
|
||||
"params": {"use_fp8_matmul": True}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# 1. Check for new metadata format
|
||||
metadata_key = f"{prefix}_quantization_metadata"
|
||||
if metadata_key in state_dict:
|
||||
try:
|
||||
metadata = state_dict.pop(metadata_key)
|
||||
if isinstance(metadata, dict) and "layers" in metadata:
|
||||
logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})")
|
||||
return metadata["layers"]
|
||||
else:
|
||||
logging.warning(f"Invalid quantization metadata format, ignoring")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to parse quantization metadata: {e}")
|
||||
return None
|
||||
def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
|
||||
# 1. Check for per-layer config in metadata
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError(f"Invalid quantization metadata format")
|
||||
|
||||
# 2. Check for legacy scaled_fp8 marker
|
||||
# If present, return None to use legacy code path
|
||||
scaled_fp8_key = f"{prefix}scaled_fp8"
|
||||
if scaled_fp8_key in state_dict:
|
||||
logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
|
||||
return None
|
||||
|
||||
# 3. Check for per-layer scale patterns (mixed precision without metadata)
|
||||
# Look for layers that have scale_weight but not all layers have it
|
||||
known_prefixes = [prefix]
|
||||
layer_configs = {}
|
||||
layers_with_scale = set()
|
||||
layers_with_weight = set()
|
||||
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix):
|
||||
if key.endswith(".scale_weight"):
|
||||
layer_name = normalize_layer_name(key, known_prefixes)
|
||||
layers_with_scale.add(layer_name)
|
||||
# Detect format based on weight dtype
|
||||
weight_key = f"{prefix}{layer_name}.weight"
|
||||
if weight_key in state_dict:
|
||||
weight_dtype = state_dict[weight_key].dtype
|
||||
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2"
|
||||
layer_configs[layer_name] = {
|
||||
"format": format_name,
|
||||
"params": {}
|
||||
}
|
||||
elif key.endswith(".weight") and not key.endswith(".scale_weight"):
|
||||
layer_name = normalize_layer_name(key, known_prefixes)
|
||||
layers_with_weight.add(layer_name)
|
||||
|
||||
# If we found scale_weight on some but not all layers, it's mixed precision
|
||||
if layer_configs and len(layers_with_scale) < len(layers_with_weight):
|
||||
logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers")
|
||||
return layer_configs
|
||||
|
||||
# 4. No quantization detected
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
|
||||
layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
|
||||
280
comfy/ops.py
280
comfy/ops.py
@ -345,19 +345,8 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
Legacy FP8 linear function - now uses tensor subclass infrastructure.
|
||||
|
||||
This function maintains backward compatibility with existing code while
|
||||
routing all FP8 computation through the unified tensor subclass system.
|
||||
All actual FP8 matmul logic is handled by the registered operation handlers
|
||||
in quant_ops.py via __torch_dispatch__.
|
||||
|
||||
Args:
|
||||
self: Linear layer with FP8 weight and scale parameters
|
||||
input: Input tensor (any dtype)
|
||||
|
||||
Returns:
|
||||
Output tensor or None if weight is not FP8
|
||||
Legacy FP8 linear function for backward compatibility.
|
||||
Uses QuantizedTensor subclass for dispatch.
|
||||
"""
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
@ -372,10 +361,8 @@ def fp8_linear(self, input):
|
||||
input_dtype = input.dtype
|
||||
|
||||
if len(input.shape) == 3:
|
||||
# Get weight and bias using standard casting
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||
|
||||
# Get scales (same as before)
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
@ -388,14 +375,13 @@ def fp8_linear(self, input):
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
|
||||
# Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
|
||||
quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
|
||||
quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
|
||||
# Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
|
||||
# This is the key unification: all FP8 computation goes through one path
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
|
||||
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
# Reshape output
|
||||
if tensor_2d:
|
||||
return o.reshape(input_shape[0], -1)
|
||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||
@ -492,183 +478,117 @@ if CUBLAS_IS_AVAILABLE:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
# Import quantization operations from separate module
|
||||
from .quant_ops import QuantizedTensorFP8
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||
|
||||
QUANT_FORMAT_MIXINS = {
|
||||
"float8_e4m3fn": {
|
||||
"dtype": torch.float8_e4m3fn,
|
||||
"layout_type": TensorCoreFP8Layout,
|
||||
"parameters": {
|
||||
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
"""
|
||||
Operations class supporting per-layer quantization (mixed precision).
|
||||
|
||||
This class enables different layers to use different quantization formats
|
||||
within the same model (e.g., some layers FP8, others BF16).
|
||||
|
||||
Layer-specific quantization is configured via _layer_quant_config class variable,
|
||||
which is set by pick_operations() when a model has mixed precision.
|
||||
"""
|
||||
|
||||
_layer_quant_config = {} # Class variable set by pick_operations()
|
||||
|
||||
class Linear(disable_weight_init.Linear):
|
||||
"""Linear layer with optional per-layer quantization using tensor subclasses"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.quant_format = None
|
||||
self.quant_scale = None
|
||||
self._quantization_initialized = False
|
||||
|
||||
_layer_quant_config = {}
|
||||
_compute_dtype = torch.bfloat16
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tensor_class = None
|
||||
|
||||
def reset_parameters(self):
|
||||
# Don't allocate weights - return None like disable_weight_init
|
||||
return None
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
"""
|
||||
Called by PyTorch during load_state_dict.
|
||||
Load weight and wrap in QuantizedTensorFP8 if this layer is quantized.
|
||||
"""
|
||||
# Call parent to load weight and bias first
|
||||
super()._load_from_state_dict(
|
||||
state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs
|
||||
)
|
||||
|
||||
# After weight is loaded, wrap it if this layer is quantized
|
||||
if not self._quantization_initialized:
|
||||
# Normalize layer name from prefix
|
||||
layer_name = prefix.rstrip('.')
|
||||
|
||||
# Strip known model prefixes
|
||||
for model_prefix in ["model.diffusion_model.", "model.model.", "net."]:
|
||||
if layer_name.startswith(model_prefix):
|
||||
layer_name = layer_name[len(model_prefix):]
|
||||
break
|
||||
|
||||
# Check if this layer has quantization config
|
||||
if layer_name in MixedPrecisionOps._layer_quant_config:
|
||||
config = MixedPrecisionOps._layer_quant_config[layer_name]
|
||||
self.quant_format = config.get("format", "fp8_e4m3fn")
|
||||
|
||||
# Load scale parameter
|
||||
scale_key = f"{prefix}scale_weight"
|
||||
if scale_key in state_dict:
|
||||
self.quant_scale = state_dict[scale_key]
|
||||
|
||||
# Wrap weight in QuantizedTensorFP8
|
||||
if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
try:
|
||||
# Determine original dtype (default to bfloat16)
|
||||
orig_dtype = torch.bfloat16
|
||||
|
||||
# Wrap weight in quantized tensor subclass
|
||||
quantized_weight = QuantizedTensorFP8(
|
||||
self.weight.data,
|
||||
self.quant_scale,
|
||||
orig_dtype=orig_dtype
|
||||
)
|
||||
|
||||
# Replace weight parameter with wrapped version
|
||||
self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
|
||||
|
||||
logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})")
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}")
|
||||
self.quant_format = None
|
||||
self.quant_scale = None
|
||||
else:
|
||||
logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization")
|
||||
self.quant_format = None
|
||||
self.quant_scale = None
|
||||
else:
|
||||
logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict")
|
||||
self.quant_format = None
|
||||
|
||||
self._quantization_initialized = True
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
"""Save layer parameters including quantization scale"""
|
||||
# First unwrap the weight if it's quantized
|
||||
if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8):
|
||||
# Temporarily unwrap to save the raw FP8 data
|
||||
quantized_tensor = self.weight.data
|
||||
raw_fp8_data = quantized_tensor._raw_data
|
||||
original_weight = self.weight
|
||||
self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False)
|
||||
|
||||
# Call parent to save unwrapped weight
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
# Restore the wrapped weight
|
||||
self.weight = original_weight
|
||||
|
||||
# Save the scale parameter
|
||||
if self.quant_scale is not None:
|
||||
destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach()
|
||||
else:
|
||||
# Standard path for non-quantized weights
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
"""
|
||||
Forward pass - tensor subclass handles dispatch automatically!
|
||||
__torch_dispatch__ will route to registered handlers based on tensor types.
|
||||
"""
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
|
||||
# Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it!
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Main forward pass"""
|
||||
run_every_op()
|
||||
# Same logic as disable_weight_init.Linear
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def conv_nd(s, dims, *args, **kwargs):
|
||||
"""Create Conv layer (same as disable_weight_init)"""
|
||||
if dims == 2:
|
||||
return s.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return s.Conv3d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
mixin = QUANT_FORMAT_MIXINS[quant_format]
|
||||
self.layout_type = mixin["layout_type"]
|
||||
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(f"{prefix}weight_scale", None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype
|
||||
}
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name, param_value in mixin["parameters"].items():
|
||||
_v = state_dict.pop(f"{prefix}{param_name}", None)
|
||||
if _v is None:
|
||||
continue
|
||||
setattr(self, param_name, _v.to(device=device))
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
self._forward(input, weight, bias)
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
"""
|
||||
Select appropriate operations class for model.
|
||||
|
||||
NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3).
|
||||
LEGACY: All other paths unchanged for backward compatibility.
|
||||
|
||||
Args:
|
||||
weight_dtype: Weight storage dtype
|
||||
compute_dtype: Computation dtype
|
||||
load_device: Device for loading
|
||||
disable_fast_fp8: Disable fast FP8 paths
|
||||
fp8_optimizations: Enable FP8 optimizations
|
||||
scaled_fp8: Legacy FP8 dtype marker
|
||||
model_config: Model config object (optional, for mixed precision support)
|
||||
|
||||
Returns:
|
||||
Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init)
|
||||
"""
|
||||
# NEW: Check for mixed precision
|
||||
# If model_config.layer_quant_config exists, use new MixedPrecisionOps.
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
|
||||
# MixedPrecisionOps._compute_dtype = compute_dtype # TODO
|
||||
MixedPrecisionOps._compute_dtype = torch.bfloat16
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return MixedPrecisionOps
|
||||
|
||||
# LEGACY paths (unchanged)
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
@ -1,42 +1,79 @@
|
||||
import torch
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
|
||||
# ==============================================================================
|
||||
# Global Operation Registry
|
||||
# ==============================================================================
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
|
||||
# Global operation registry: torch operation → handler function
|
||||
_QUANT_OP_REGISTRY = {}
|
||||
|
||||
def register_quant_op(torch_op):
|
||||
def register_layout_op(torch_op, layout_type):
|
||||
"""
|
||||
Decorator to register an operation handler.
|
||||
|
||||
Decorator to register a layout-specific operation handler.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
||||
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
Example:
|
||||
@register_quant_op(torch.ops.aten.linear.default)
|
||||
def handle_linear_fp8(func, args, kwargs):
|
||||
# Implementation
|
||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||
def fp8_linear(func, args, kwargs):
|
||||
# FP8-specific linear implementation
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
_QUANT_OP_REGISTRY[torch_op] = handler_func
|
||||
if torch_op not in _LAYOUT_REGISTRY:
|
||||
_LAYOUT_REGISTRY[torch_op] = {}
|
||||
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def get_quant_handler(torch_op):
|
||||
"""Get registered handler for an operation"""
|
||||
return _QUANT_OP_REGISTRY.get(torch_op)
|
||||
def register_generic_util(torch_op):
|
||||
"""
|
||||
Decorator to register a generic utility that works for all layouts.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||
|
||||
Example:
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
# Works for any layout
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
_GENERIC_UTILS[torch_op] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def list_registered_ops():
|
||||
"""List all registered quantized operations"""
|
||||
return list(_QUANT_OP_REGISTRY.keys())
|
||||
def _get_layout_from_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg._layout_type
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
for item in arg:
|
||||
if isinstance(item, QuantizedTensor):
|
||||
return item._layout_type
|
||||
return None
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# comfy_kitchen Integration
|
||||
# ==============================================================================
|
||||
def _move_layout_params_to_device(params, device):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.to(device=device)
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
def _copy_layout_params(params):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.clone()
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
@ -53,106 +90,144 @@ except Exception as e:
|
||||
logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks")
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Quantized Tensor Subclass
|
||||
# ==============================================================================
|
||||
|
||||
class QuantizedTensorFP8(torch.Tensor):
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
Tensor subclass for FP8 quantized data.
|
||||
Automatically handles operations via __torch_dispatch__.
|
||||
Base class for quantization layouts.
|
||||
|
||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||
|
||||
New quantization formats should subclass this and implement the required methods.
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||
raise NotImplementedError(f"TensorLayout must implement dequantize()")
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||
|
||||
|
||||
class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
Universal quantized tensor that works with any layout.
|
||||
|
||||
This tensor subclass uses a pluggable layout system to support multiple
|
||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||
|
||||
The layout_type determines format-specific behavior, while common operations
|
||||
(detach, clone, to) are handled generically.
|
||||
|
||||
Attributes:
|
||||
_qdata: The quantized tensor data
|
||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16):
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized FP8 tensor.
|
||||
Create a quantized tensor.
|
||||
|
||||
Args:
|
||||
tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2)
|
||||
scale: Scale factor for dequantization (scalar tensor)
|
||||
orig_dtype: Original dtype before quantization
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_subclass(cls, tensor, require_grad=False)
|
||||
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
|
||||
|
||||
def __init__(self, tensor, scale, orig_dtype=torch.bfloat16):
|
||||
self._scale = scale
|
||||
self._orig_dtype = orig_dtype
|
||||
# Store a reference to prevent infinite recursion in dequantize
|
||||
self._raw_data = tensor.contiguous()
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata.contiguous()
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
return (f"QuantizedTensorFP8(shape={self.shape}, "
|
||||
f"scale={self._scale:.4f}, dtype={self._orig_dtype})")
|
||||
layout_name = self._layout_type.__name__
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
@property
|
||||
def layout_type(self):
|
||||
return self._layout_type
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
"""
|
||||
Tensor flattening protocol for proper device movement.
|
||||
"""
|
||||
inner_tensors = ["_q_data"]
|
||||
ctx = {
|
||||
"layout_type": self._layout_type,
|
||||
}
|
||||
|
||||
tensor_params = {}
|
||||
non_tensor_params = {}
|
||||
for k, v in self._layout_params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tensor_params[k] = v
|
||||
else:
|
||||
non_tensor_params[k] = v
|
||||
|
||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||
ctx["non_tensor_params"] = non_tensor_params
|
||||
|
||||
for k, v in tensor_params.items():
|
||||
attr_name = f"_layout_param_{k}"
|
||||
object.__setattr__(self, attr_name, v)
|
||||
inner_tensors.append(attr_name)
|
||||
|
||||
return inner_tensors, ctx
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||
"""
|
||||
Tensor unflattening protocol for proper device movement.
|
||||
Reconstructs the QuantizedTensor after device movement.
|
||||
"""
|
||||
layout_type = ctx["layout_type"]
|
||||
layout_params = dict(ctx["non_tensor_params"])
|
||||
|
||||
for key in ctx["tensor_param_keys"]:
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
|
||||
|
||||
tensor_fp8 = None
|
||||
if _CK_AVAILABLE:
|
||||
try:
|
||||
tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
|
||||
except Exception as e:
|
||||
logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
|
||||
|
||||
if tensor_fp8 is None:
|
||||
lp_amax = torch.finfo(fp8_dtype).max
|
||||
tensor_scaled = tensor.float() / scale
|
||||
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||
tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
|
||||
|
||||
return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
|
||||
return cls(qdata, layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
|
||||
if strategy == "amax":
|
||||
scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max
|
||||
scale = scale.to(tensor.device, dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization strategy: {strategy}. "
|
||||
f"Supported: 'amax'")
|
||||
|
||||
return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
return self._layout_type.dequantize(self._qdata, **self._layout_params)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# Special case: skip dispatch for internal tensor operations
|
||||
# that are used for unwrapping (to avoid recursion)
|
||||
if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]:
|
||||
# For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach
|
||||
if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8):
|
||||
# Special handling for detach - return a new QuantizedTensorFP8
|
||||
qt = args[0]
|
||||
detached_data = qt._raw_data.detach()
|
||||
return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype)
|
||||
|
||||
# For other ops, just unwrap
|
||||
def unwrap(arg):
|
||||
if isinstance(arg, QuantizedTensorFP8):
|
||||
return arg._raw_data
|
||||
return arg
|
||||
new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args)
|
||||
return func(*new_args, **kwargs)
|
||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||
if func in _GENERIC_UTILS:
|
||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||
|
||||
# Look up registered handler for this operation
|
||||
handler = _QUANT_OP_REGISTRY.get(func)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||
layout_type = _get_layout_from_args(args)
|
||||
if layout_type and func in _LAYOUT_REGISTRY:
|
||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
|
||||
# No handler - dequantize and use standard path
|
||||
# Step 3: Fallback to dequantization
|
||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||
logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||
return cls._dequant_and_fallback(func, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||
"""Fallback: dequantize all quantized tensors"""
|
||||
def dequant_arg(arg):
|
||||
if isinstance(arg, QuantizedTensorFP8):
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg.dequantize()
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return type(arg)(dequant_arg(a) for a in arg)
|
||||
@ -161,75 +236,220 @@ class QuantizedTensorFP8(torch.Tensor):
|
||||
new_args = dequant_arg(args)
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
|
||||
return plain_tensor * self._scale
|
||||
|
||||
def detach(self):
|
||||
"""Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
|
||||
detached_data = self._raw_data.detach()
|
||||
return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Operation Handlers for Quantized Tensors
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# ==============================================================================
|
||||
|
||||
@register_quant_op(torch.ops.aten.linear.default)
|
||||
def handle_linear_fp8(func, args, kwargs):
|
||||
"""
|
||||
Handle F.linear() with quantized inputs.
|
||||
def _create_transformed_qtensor(qt, transform_fn):
|
||||
new_data = transform_fn(qt._qdata)
|
||||
new_params = _copy_layout_params(qt._layout_params)
|
||||
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_dtype is not None and target_dtype != qt.dtype:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
Supports:
|
||||
- QuantizedTensorFP8 input + QuantizedTensorFP8 weight
|
||||
- QuantizedTensorFP8 input + regular weight
|
||||
- Regular input + QuantizedTensorFP8 weight
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
f"but not supported. Ignoring layout."
|
||||
)
|
||||
|
||||
# Handle device transfer
|
||||
current_device = qt._qdata.device
|
||||
if target_device is not None:
|
||||
# Normalize device for comparison
|
||||
if isinstance(target_device, str):
|
||||
target_device = torch.device(target_device)
|
||||
if isinstance(current_device, str):
|
||||
current_device = torch.device(current_device)
|
||||
|
||||
if target_device != current_device:
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||
return qt
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
"""Detach operation - creates a detached copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.clone.default)
|
||||
def generic_clone(func, args, kwargs):
|
||||
"""Clone operation - creates a deep copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._to_copy.default)
|
||||
def generic_to_copy(func, args, kwargs):
|
||||
"""Device/dtype transfer operation - handles .to(device) calls."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
op_name="_to_copy"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
||||
def generic_to_dtype_layout(func, args, kwargs):
|
||||
"""Handle .to(device) calls using the dtype_layout variant."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
target_layout=kwargs.get('layout', None),
|
||||
op_name="to"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.copy_.default)
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
qt_dest._layout_params = _copy_layout_params(src._layout_params)
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
return qt_dest
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layout + Operation Handlers
|
||||
# ==============================================================================
|
||||
class TensorCoreFP8Layout(QuantizedLayout):
|
||||
"""
|
||||
Storage format:
|
||||
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||
- scale: Scalar tensor (float32) for dequantization
|
||||
- orig_dtype: Original dtype before quantization (for casting back)
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if scale is None:
|
||||
scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if _CK_AVAILABLE and tensor.device.type == "cuda":
|
||||
qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
|
||||
else:
|
||||
lp_amax = torch.finfo(fp8_dtype).max
|
||||
tensor_scaled = tensor.float() / scale
|
||||
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||
qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
|
||||
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': orig_dtype
|
||||
}
|
||||
return qdata, layout_params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
return plain_tensor * scale
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
return qtensor._qdata, qtensor._layout_params['scale']
|
||||
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||
def fp8_linear(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
bias = args[2] if len(args) > 2 else None
|
||||
out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype)
|
||||
|
||||
# Case 1: Both input and weight are FP8
|
||||
if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
|
||||
# Get plain tensors to avoid dispatch recursion
|
||||
plain_input = input_tensor._raw_data
|
||||
plain_weight = weight._raw_data
|
||||
weight_t = plain_weight.t() # Keep as column-major for cuBLASLt
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
out_dtype = kwargs.get("out_dtype")
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
weight_t = plain_weight.t()
|
||||
|
||||
tensor_2d = False
|
||||
if len(plain_input.shape) == 2:
|
||||
tensor_2d = True
|
||||
plain_input = plain_input.unsqueeze(1)
|
||||
|
||||
input_shape = plain_input.shape
|
||||
if len(input_shape) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
output = torch._scaled_mm(
|
||||
plain_input,
|
||||
plain_input.reshape(-1, input_shape[2]),
|
||||
weight_t,
|
||||
bias=bias,
|
||||
scale_a=input_tensor._scale,
|
||||
scale_b=weight._scale,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
|
||||
if not tensor_2d:
|
||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||
|
||||
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
output_scale = input_tensor._scale * weight._scale
|
||||
return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? Can't cuBLAS return it
|
||||
output_scale = scale_a * scale_b
|
||||
output_params = {
|
||||
'scale': output_scale,
|
||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||
}
|
||||
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
|
||||
else:
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
|
||||
# Case 2: Only weight is quantized
|
||||
if isinstance(weight, QuantizedTensorFP8):
|
||||
weight_dq = weight.dequantize()
|
||||
input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor
|
||||
return torch.nn.functional.linear(input_dq, weight_dq, bias)
|
||||
|
||||
# Case 3: Only input is quantized
|
||||
elif isinstance(input_tensor, QuantizedTensorFP8):
|
||||
input_dq = input_tensor.dequantize()
|
||||
return torch.nn.functional.linear(input_dq, weight, bias)
|
||||
|
||||
# Case 4: Neither is quantized (shouldn't happen, but handle it)
|
||||
else:
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
# Case 2: DQ Fallback
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
input_tensor = input_tensor.dequantize()
|
||||
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
||||
@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
"""
|
||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||
|
||||
@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
new_sd = sd
|
||||
@ -1346,8 +1346,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user