Mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00

Commit b6e0a53c11: "An actually functional POC"
Parent commit: 19ce6b056d
@@ -7,121 +7,24 @@ import logging
 import torch
 
 
-# ==============================================================================
-# Quantization Detection Functions
-# ==============================================================================
-
-def normalize_layer_name(full_key, known_prefixes):
-    """
-    Strip model prefix and parameter suffix from a state dict key.
-
-    Args:
-        full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight")
-        known_prefixes: List of known model prefixes to strip
-
-    Returns:
-        Normalized layer name (e.g., "layer1")
-    """
-    name = full_key
-
-    # Strip model prefix
-    for prefix in known_prefixes:
-        if name.startswith(prefix):
-            name = name[len(prefix):]
-            break
-
-    # Remove parameter suffix
-    for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]:
-        if name.endswith(suffix):
-            name = name[:-len(suffix)]
-            break
-
-    return name
-
-
-def detect_layer_quantization(state_dict, prefix="model.diffusion_model."):
-    """
-    Detect per-layer quantization configuration from state dict.
-
-    Detection priority:
-    1. Check for _quantization_metadata key (new format)
-    2. Check for scaled_fp8 key (legacy format - return None)
-    3. Check for per-layer scale_weight patterns (mixed detection)
-    4. No quantization detected (return None)
-
-    Args:
-        state_dict: Model state dictionary
-        prefix: Key prefix for model layers
-
-    Returns:
-        Dict mapping layer names to quantization configs, or None for legacy/no quantization.
-
-    Example return value:
-        {
-            "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {
-                "format": "fp8_e4m3fn_scaled",
-                "params": {"use_fp8_matmul": True}
-            },
-            "middle_block.1.transformer_blocks.0.attn2.to_k": {
-                "format": "fp8_e5m2_scaled",
-                "params": {"use_fp8_matmul": True}
-            }
-        }
-    """
-
-    # 1. Check for new metadata format
-    metadata_key = f"{prefix}_quantization_metadata"
-    if metadata_key in state_dict:
-        try:
-            metadata = state_dict.pop(metadata_key)
-            if isinstance(metadata, dict) and "layers" in metadata:
-                logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})")
-                return metadata["layers"]
-            else:
-                logging.warning(f"Invalid quantization metadata format, ignoring")
-        except Exception as e:
-            logging.error(f"Failed to parse quantization metadata: {e}")
-            return None
+def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
+    # 1. Check for per-layer config in metadata
+    quant_key = "_quantization_metadata"
+    if metadata is not None and quant_key in metadata:
+        quant_metadata = metadata.pop(quant_key)
+        quant_metadata = json.loads(quant_metadata)
+        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
+            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
+            return quant_metadata["layers"]
+        else:
+            raise ValueError(f"Invalid quantization metadata format")
 
     # 2. Check for legacy scaled_fp8 marker
-    # If present, return None to use legacy code path
     scaled_fp8_key = f"{prefix}scaled_fp8"
     if scaled_fp8_key in state_dict:
         logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
         return None
 
-    # 3. Check for per-layer scale patterns (mixed precision without metadata)
-    # Look for layers that have scale_weight but not all layers have it
-    known_prefixes = [prefix]
-    layer_configs = {}
-    layers_with_scale = set()
-    layers_with_weight = set()
-
-    for key in state_dict.keys():
-        if key.startswith(prefix):
-            if key.endswith(".scale_weight"):
-                layer_name = normalize_layer_name(key, known_prefixes)
-                layers_with_scale.add(layer_name)
-                # Detect format based on weight dtype
-                weight_key = f"{prefix}{layer_name}.weight"
-                if weight_key in state_dict:
-                    weight_dtype = state_dict[weight_key].dtype
-                    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                        format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2"
-                        layer_configs[layer_name] = {
-                            "format": format_name,
-                            "params": {}
-                        }
-            elif key.endswith(".weight") and not key.endswith(".scale_weight"):
-                layer_name = normalize_layer_name(key, known_prefixes)
-                layers_with_weight.add(layer_name)
-
-    # If we found scale_weight on some but not all layers, it's mixed precision
-    if layer_configs and len(layers_with_scale) < len(layers_with_weight):
-        logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers")
-        return layer_configs
-
-    # 4. No quantization detected
     return None
@@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         model_config.optimizations["fp8"] = True
 
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
+    layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix)
     if layer_quant_config:
         model_config.layer_quant_config = layer_quant_config
         logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
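For context, a minimal sketch (not part of this commit) of the `_quantization_metadata` entry that the new `detect_layer_quantization(state_dict, metadata, ...)` looks for in the checkpoint metadata; the layer name and `format_version` value are illustrative assumptions:

    import json

    # Hypothetical safetensors metadata dict passed through model_config_from_unet.
    metadata = {
        "_quantization_metadata": json.dumps({
            "format_version": "1",  # assumed value; only logged, not validated
            "layers": {
                # keys must match the layer prefixes seen by MixedPrecisionOps.Linear
                "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {"format": "float8_e4m3fn"},
            },
        })
    }

    # detect_layer_quantization(sd, metadata) json-loads the blob and returns the
    # inner "layers" dict; with no metadata and no scaled_fp8 marker it returns None.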
comfy/ops.py (254)
@@ -345,19 +345,8 @@ class manual_cast(disable_weight_init):
 
 def fp8_linear(self, input):
     """
-    Legacy FP8 linear function - now uses tensor subclass infrastructure.
-
-    This function maintains backward compatibility with existing code while
-    routing all FP8 computation through the unified tensor subclass system.
-    All actual FP8 matmul logic is handled by the registered operation handlers
-    in quant_ops.py via __torch_dispatch__.
-
-    Args:
-        self: Linear layer with FP8 weight and scale parameters
-        input: Input tensor (any dtype)
-
-    Returns:
-        Output tensor or None if weight is not FP8
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
     """
     dtype = self.weight.dtype
     if dtype not in [torch.float8_e4m3fn]:
@@ -372,10 +361,8 @@ def fp8_linear(self, input):
     input_dtype = input.dtype
 
     if len(input.shape) == 3:
-        # Get weight and bias using standard casting
         w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
 
-        # Get scales (same as before)
         scale_weight = self.scale_weight
         scale_input = self.scale_input
         if scale_weight is None:
@@ -388,14 +375,13 @@ def fp8_linear(self, input):
         else:
             scale_input = scale_input.to(input.device)
 
-        # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
-        quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
-        quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
-        # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
-        # This is the key unification: all FP8 computation goes through one path
+        # Wrap weight in QuantizedTensor - this enables unified dispatch
+        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
+        quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype)
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
 
-        # Reshape output
         if tensor_2d:
             return o.reshape(input_shape[0], -1)
         return o.reshape((-1, input_shape[1], self.weight.shape[0]))
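A minimal sketch (not part of the diff) of the wrapping pattern the rewritten `fp8_linear` relies on: the FP8 weight is boxed in a `QuantizedTensor` with `TensorCoreFP8Layout`, and dequantization reproduces `weight * scale`. The module path and shapes are assumptions, and the fast `torch._scaled_mm` path additionally needs CUDA FP8 support.

    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout  # assumed module path

    w_fp8 = torch.randn(128, 64).to(torch.float8_e4m3fn)      # already-quantized weight
    scale_w = torch.tensor(0.02, dtype=torch.float32)          # per-tensor scale

    qw = QuantizedTensor(w_fp8, TensorCoreFP8Layout,
                         {'scale': scale_w, 'orig_dtype': torch.bfloat16})

    # dequantize() casts the raw FP8 payload back to orig_dtype and multiplies by scale,
    # so w_bf16 is a (128, 64) bfloat16 approximation of the original weight.
    w_bf16 = qw.dequantize()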
@@ -492,183 +478,117 @@ if CUBLAS_IS_AVAILABLE:
             return super().forward(*args, **kwargs)
 
 
-# Import quantization operations from separate module
-from .quant_ops import QuantizedTensorFP8
-
-
 # ==============================================================================
 # Mixed Precision Operations
 # ==============================================================================
+from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
+
+QUANT_FORMAT_MIXINS = {
+    "float8_e4m3fn": {
+        "dtype": torch.float8_e4m3fn,
+        "layout_type": TensorCoreFP8Layout,
+        "parameters": {
+            "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+            "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
+        }
+    }
+}
 
 class MixedPrecisionOps(disable_weight_init):
-    """
-    Operations class supporting per-layer quantization (mixed precision).
-
-    This class enables different layers to use different quantization formats
-    within the same model (e.g., some layers FP8, others BF16).
-
-    Layer-specific quantization is configured via _layer_quant_config class variable,
-    which is set by pick_operations() when a model has mixed precision.
-    """
-
-    _layer_quant_config = {}  # Class variable set by pick_operations()
-
-    class Linear(disable_weight_init.Linear):
-        """Linear layer with optional per-layer quantization using tensor subclasses"""
-
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.quant_format = None
-            self.quant_scale = None
-            self._quantization_initialized = False
+    _layer_quant_config = {}
+    _compute_dtype = torch.bfloat16
+
+    class Linear(torch.nn.Module, CastWeightBiasOp):
+        def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            device=None,
+            dtype=None,
+        ) -> None:
+            super().__init__()
+            self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+            # self.factory_kwargs = {"device": device, "dtype": dtype}
+
+            self.in_features = in_features
+            self.out_features = out_features
+            if bias:
+                self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+            else:
+                self.register_parameter("bias", None)
+
+            self.tensor_class = None
 
         def reset_parameters(self):
-            # Don't allocate weights - return None like disable_weight_init
             return None
 
         def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                   strict, missing_keys, unexpected_keys, error_msgs):
-            """
-            Called by PyTorch during load_state_dict.
-            Load weight and wrap in QuantizedTensorFP8 if this layer is quantized.
-            """
-            # Call parent to load weight and bias first
-            super()._load_from_state_dict(
-                state_dict, prefix, local_metadata,
-                strict, missing_keys, unexpected_keys, error_msgs
-            )
-
-            # After weight is loaded, wrap it if this layer is quantized
-            if not self._quantization_initialized:
-                # Normalize layer name from prefix
-                layer_name = prefix.rstrip('.')
-
-                # Strip known model prefixes
-                for model_prefix in ["model.diffusion_model.", "model.model.", "net."]:
-                    if layer_name.startswith(model_prefix):
-                        layer_name = layer_name[len(model_prefix):]
-                        break
-
-                # Check if this layer has quantization config
-                if layer_name in MixedPrecisionOps._layer_quant_config:
-                    config = MixedPrecisionOps._layer_quant_config[layer_name]
-                    self.quant_format = config.get("format", "fp8_e4m3fn")
-
-                    # Load scale parameter
-                    scale_key = f"{prefix}scale_weight"
-                    if scale_key in state_dict:
-                        self.quant_scale = state_dict[scale_key]
-
-                        # Wrap weight in QuantizedTensorFP8
-                        if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                            try:
-                                # Determine original dtype (default to bfloat16)
-                                orig_dtype = torch.bfloat16
-
-                                # Wrap weight in quantized tensor subclass
-                                quantized_weight = QuantizedTensorFP8(
-                                    self.weight.data,
-                                    self.quant_scale,
-                                    orig_dtype=orig_dtype
-                                )
-
-                                # Replace weight parameter with wrapped version
-                                self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
-
-                                logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})")
-                            except Exception as e:
-                                logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}")
-                                self.quant_format = None
-                                self.quant_scale = None
-                        else:
-                            logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization")
-                            self.quant_format = None
-                            self.quant_scale = None
-                    else:
-                        logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict")
-                        self.quant_format = None
-
-            self._quantization_initialized = True
-
-        def _save_to_state_dict(self, destination, prefix, keep_vars):
-            """Save layer parameters including quantization scale"""
-            # First unwrap the weight if it's quantized
-            if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8):
-                # Temporarily unwrap to save the raw FP8 data
-                quantized_tensor = self.weight.data
-                raw_fp8_data = quantized_tensor._raw_data
-                original_weight = self.weight
-                self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False)
-
-                # Call parent to save unwrapped weight
-                super()._save_to_state_dict(destination, prefix, keep_vars)
-
-                # Restore the wrapped weight
-                self.weight = original_weight
-
-                # Save the scale parameter
-                if self.quant_scale is not None:
-                    destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach()
-            else:
-                # Standard path for non-quantized weights
-                super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        def forward_comfy_cast_weights(self, input):
-            """
-            Forward pass - tensor subclass handles dispatch automatically!
-            __torch_dispatch__ will route to registered handlers based on tensor types.
-            """
-            weight, bias = cast_bias_weight(self, input)
-
-            # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it!
-            return torch.nn.functional.linear(input, weight, bias)
-
-        def forward(self, *args, **kwargs):
-            """Main forward pass"""
-            run_every_op()
-            # Same logic as disable_weight_init.Linear
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
-                return self.forward_comfy_cast_weights(*args, **kwargs)
-            else:
-                return super().forward(*args, **kwargs)
-
-    @classmethod
-    def conv_nd(s, dims, *args, **kwargs):
-        """Create Conv layer (same as disable_weight_init)"""
-        if dims == 2:
-            return s.Conv2d(*args, **kwargs)
-        elif dims == 3:
-            return s.Conv3d(*args, **kwargs)
-        else:
-            raise ValueError(f"unsupported dimensions: {dims}")
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                layout_params = {
+                    'scale': state_dict.pop(f"{prefix}weight_scale", None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                for param_name, param_value in mixin["parameters"].items():
+                    _v = state_dict.pop(f"{prefix}{param_name}", None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, _v.to(device=device))
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            self._forward(input, weight, bias)
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(input, *args, **kwargs)
+            if (getattr(self, 'layout_type', None) is not None and
+                getattr(self, 'input_scale', None) is not None and
+                not isinstance(input, QuantizedTensor)):
+                input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
+            return self._forward(input, self.weight, self.bias)
 
 
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
-    """
-    Select appropriate operations class for model.
-
-    NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3).
-    LEGACY: All other paths unchanged for backward compatibility.
-
-    Args:
-        weight_dtype: Weight storage dtype
-        compute_dtype: Computation dtype
-        load_device: Device for loading
-        disable_fast_fp8: Disable fast FP8 paths
-        fp8_optimizations: Enable FP8 optimizations
-        scaled_fp8: Legacy FP8 dtype marker
-        model_config: Model config object (optional, for mixed precision support)
-
-    Returns:
-        Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init)
-    """
-    # NEW: Check for mixed precision
+    # If model_config.layer_quant_config exists, use new MixedPrecisionOps.
     if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
         MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
+        # MixedPrecisionOps._compute_dtype = compute_dtype # TODO
+        MixedPrecisionOps._compute_dtype = torch.bfloat16
         logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
         return MixedPrecisionOps
 
-    # LEGACY paths (unchanged)
     fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
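A small sketch (not part of the diff) of the configuration that drives `MixedPrecisionOps`: `pick_operations` copies `model_config.layer_quant_config` onto the class, and `Linear._load_from_state_dict` then expects the weight plus the mixin parameters (`weight_scale`, `input_scale`) for each quantized layer. The layer name below is illustrative.

    # Per-layer config as attached by model_detection; "format" keys index QUANT_FORMAT_MIXINS.
    layer_quant_config = {
        "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {"format": "float8_e4m3fn"},
    }

    # For such a layer the loader pops these state dict keys (prefix already ends with "."):
    #   "<layer>.weight"        -> raw FP8 weight, wrapped in QuantizedTensor
    #   "<layer>.weight_scale"  -> becomes layout_params['scale']
    #   "<layer>.input_scale"   -> stored on the module for on-the-fly input quantization
    # Layers not listed are loaded as plain parameters in _compute_dtype (bfloat16).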
@@ -1,42 +1,79 @@
 import torch
 import logging
+from typing import Tuple, Dict
 
-# ==============================================================================
-# Global Operation Registry
-# ==============================================================================
-
-# Global operation registry: torch operation → handler function
-_QUANT_OP_REGISTRY = {}
+_LAYOUT_REGISTRY = {}
+_GENERIC_UTILS = {}
 
-def register_quant_op(torch_op):
+def register_layout_op(torch_op, layout_type):
     """
-    Decorator to register an operation handler.
+    Decorator to register a layout-specific operation handler.
 
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
+        layout_type: Layout class (e.g., TensorCoreFP8Layout)
+
     Example:
-        @register_quant_op(torch.ops.aten.linear.default)
-        def handle_linear_fp8(func, args, kwargs):
-            # Implementation
+        @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+        def fp8_linear(func, args, kwargs):
+            # FP8-specific linear implementation
             ...
     """
     def decorator(handler_func):
-        _QUANT_OP_REGISTRY[torch_op] = handler_func
+        if torch_op not in _LAYOUT_REGISTRY:
+            _LAYOUT_REGISTRY[torch_op] = {}
+        _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
         return handler_func
     return decorator
 
 
-def get_quant_handler(torch_op):
-    """Get registered handler for an operation"""
-    return _QUANT_OP_REGISTRY.get(torch_op)
+def register_generic_util(torch_op):
+    """
+    Decorator to register a generic utility that works for all layouts.
+
+    Args:
+        torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
+
+    Example:
+        @register_generic_util(torch.ops.aten.detach.default)
+        def generic_detach(func, args, kwargs):
+            # Works for any layout
+            ...
+    """
+    def decorator(handler_func):
+        _GENERIC_UTILS[torch_op] = handler_func
+        return handler_func
+    return decorator
 
 
-def list_registered_ops():
-    """List all registered quantized operations"""
-    return list(_QUANT_OP_REGISTRY.keys())
+def _get_layout_from_args(args):
+    for arg in args:
+        if isinstance(arg, QuantizedTensor):
+            return arg._layout_type
+        elif isinstance(arg, (list, tuple)):
+            for item in arg:
+                if isinstance(item, QuantizedTensor):
+                    return item._layout_type
+    return None
 
 
-# ==============================================================================
-# comfy_kitchen Integration
-# ==============================================================================
+def _move_layout_params_to_device(params, device):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.to(device=device)
+        else:
+            new_params[k] = v
+    return new_params
+
+
+def _copy_layout_params(params):
+    new_params = {}
+    for k, v in params.items():
+        if isinstance(v, torch.Tensor):
+            new_params[k] = v.clone()
+        else:
+            new_params[k] = v
+    return new_params
 
 
 try:
     import comfy_kitchen as ck
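A minimal sketch (not part of the diff) of how the two registries are meant to be used: layout-specific handlers are keyed by (op, layout) pairs, generic utilities by op alone, and `__torch_dispatch__` consults `_GENERIC_UTILS` before `_LAYOUT_REGISTRY`. The `Int4Layout` name and module path are made-up placeholders.

    import torch
    from comfy.quant_ops import register_layout_op, _LAYOUT_REGISTRY  # assumed module path

    class Int4Layout:  # hypothetical future layout
        pass

    @register_layout_op(torch.ops.aten.linear.default, Int4Layout)
    def int4_linear(func, args, kwargs):
        # a format-specific implementation would go here
        raise NotImplementedError

    # The dispatcher resolves handlers with the same two-level lookup:
    handler = _LAYOUT_REGISTRY[torch.ops.aten.linear.default][Int4Layout]
    assert handler is int4_linear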
@@ -53,106 +90,144 @@ except Exception as e:
     logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks")
 
 
-# ==============================================================================
-# Quantized Tensor Subclass
-# ==============================================================================
-
-class QuantizedTensorFP8(torch.Tensor):
-    """
-    Tensor subclass for FP8 quantized data.
-    Automatically handles operations via __torch_dispatch__.
+class QuantizedLayout:
+    """
+    Base class for quantization layouts.
+
+    A layout encapsulates the format-specific logic for quantization/dequantization
+    and provides a uniform interface for extracting raw tensors needed for computation.
+
+    New quantization formats should subclass this and implement the required methods.
+    """
+    @classmethod
+    def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
+        raise NotImplementedError(f"{cls.__name__} must implement quantize()")
+
+    @staticmethod
+    def dequantize(qdata, **layout_params) -> torch.Tensor:
+        raise NotImplementedError(f"TensorLayout must implement dequantize()")
+
+    @classmethod
+    def get_plain_tensors(cls, qtensor) -> torch.Tensor:
+        raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
+
+
+class QuantizedTensor(torch.Tensor):
+    """
+    Universal quantized tensor that works with any layout.
+
+    This tensor subclass uses a pluggable layout system to support multiple
+    quantization formats (FP8, INT4, INT8, etc.) without code duplication.
+
+    The layout_type determines format-specific behavior, while common operations
+    (detach, clone, to) are handled generically.
+
+    Attributes:
+        _qdata: The quantized tensor data
+        _layout_type: Layout class (e.g., TensorCoreFP8Layout)
+        _layout_params: Dict with layout-specific params (scale, zero_point, etc.)
     """
 
     @staticmethod
-    def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16):
+    def __new__(cls, qdata, layout_type, layout_params):
         """
-        Create a quantized FP8 tensor.
+        Create a quantized tensor.
 
         Args:
-            tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2)
-            scale: Scale factor for dequantization (scalar tensor)
-            orig_dtype: Original dtype before quantization
+            qdata: The quantized data tensor
+            layout_type: Layout class (subclass of QuantizedLayout)
+            layout_params: Dict with layout-specific parameters
         """
-        return torch.Tensor._make_subclass(cls, tensor, require_grad=False)
+        return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
 
-    def __init__(self, tensor, scale, orig_dtype=torch.bfloat16):
-        self._scale = scale
-        self._orig_dtype = orig_dtype
-        # Store a reference to prevent infinite recursion in dequantize
-        self._raw_data = tensor.contiguous()
+    def __init__(self, qdata, layout_type, layout_params):
+        self._qdata = qdata.contiguous()
+        self._layout_type = layout_type
+        self._layout_params = layout_params
 
     def __repr__(self):
-        return (f"QuantizedTensorFP8(shape={self.shape}, "
-                f"scale={self._scale:.4f}, dtype={self._orig_dtype})")
+        layout_name = self._layout_type.__name__
+        param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
+        return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
+
+    @property
+    def layout_type(self):
+        return self._layout_type
+
+    def __tensor_flatten__(self):
+        """
+        Tensor flattening protocol for proper device movement.
+        """
+        inner_tensors = ["_q_data"]
+        ctx = {
+            "layout_type": self._layout_type,
+        }
+
+        tensor_params = {}
+        non_tensor_params = {}
+        for k, v in self._layout_params.items():
+            if isinstance(v, torch.Tensor):
+                tensor_params[k] = v
+            else:
+                non_tensor_params[k] = v
+
+        ctx["tensor_param_keys"] = list(tensor_params.keys())
+        ctx["non_tensor_params"] = non_tensor_params
+
+        for k, v in tensor_params.items():
+            attr_name = f"_layout_param_{k}"
+            object.__setattr__(self, attr_name, v)
+            inner_tensors.append(attr_name)
+
+        return inner_tensors, ctx
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
+        """
+        Tensor unflattening protocol for proper device movement.
+        Reconstructs the QuantizedTensor after device movement.
+        """
+        layout_type = ctx["layout_type"]
+        layout_params = dict(ctx["non_tensor_params"])
+
+        for key in ctx["tensor_param_keys"]:
+            attr_name = f"_layout_param_{key}"
+            layout_params[key] = inner_tensors[attr_name]
+
+        return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
 
     @classmethod
-    def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
-        orig_dtype = tensor.dtype
-
-        if not isinstance(scale, torch.Tensor):
-            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
-
-        tensor_fp8 = None
-        if _CK_AVAILABLE:
-            try:
-                tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
-            except Exception as e:
-                logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
-
-        if tensor_fp8 is None:
-            lp_amax = torch.finfo(fp8_dtype).max
-            tensor_scaled = tensor.float() / scale
-            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-            tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
-
-        return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
-
-    @classmethod
-    def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
-        if strategy == "amax":
-            scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max
-            scale = scale.to(tensor.device, dtype=torch.float32)
-        else:
-            raise ValueError(f"Unknown quantization strategy: {strategy}. "
-                             f"Supported: 'amax'")
-
-        return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
+    def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
+        qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
+        return cls(qdata, layout_type, layout_params)
+
+    def dequantize(self) -> torch.Tensor:
+        return self._layout_type.dequantize(self._qdata, **self._layout_params)
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
 
-        # Special case: skip dispatch for internal tensor operations
-        # that are used for unwrapping (to avoid recursion)
-        if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]:
-            # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach
-            if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8):
-                # Special handling for detach - return a new QuantizedTensorFP8
-                qt = args[0]
-                detached_data = qt._raw_data.detach()
-                return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype)
-
-            # For other ops, just unwrap
-            def unwrap(arg):
-                if isinstance(arg, QuantizedTensorFP8):
-                    return arg._raw_data
-                return arg
-            new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args)
-            return func(*new_args, **kwargs)
-
-        # Look up registered handler for this operation
-        handler = _QUANT_OP_REGISTRY.get(func)
-        if handler:
-            return handler(func, args, kwargs)
-
-        # No handler - dequantize and use standard path
+        # Step 1: Check generic utilities first (detach, clone, to, etc.)
+        if func in _GENERIC_UTILS:
+            return _GENERIC_UTILS[func](func, args, kwargs)
+
+        # Step 2: Check layout-specific handlers (linear, matmul, etc.)
+        layout_type = _get_layout_from_args(args)
+        if layout_type and func in _LAYOUT_REGISTRY:
+            handler = _LAYOUT_REGISTRY[func].get(layout_type)
+            if handler:
+                return handler(func, args, kwargs)
+
+        # Step 3: Fallback to dequantization
+        if isinstance(args[0] if args else None, QuantizedTensor):
+            logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
+
         return cls._dequant_and_fallback(func, args, kwargs)
 
     @classmethod
     def _dequant_and_fallback(cls, func, args, kwargs):
-        """Fallback: dequantize all quantized tensors"""
         def dequant_arg(arg):
-            if isinstance(arg, QuantizedTensorFP8):
+            if isinstance(arg, QuantizedTensor):
                 return arg.dequantize()
             elif isinstance(arg, (list, tuple)):
                 return type(arg)(dequant_arg(a) for a in arg)
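A short usage sketch (not part of the diff) of the subclass API above, assuming the module lands as `comfy.quant_ops` and using the `TensorCoreFP8Layout` defined further down: `from_float` delegates quantization to the layout, `dequantize` reverses it through the same layout, and ops like `detach` are routed through the generic utilities via `__torch_dispatch__`.

    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout  # assumed module path

    x = torch.randn(32, 16, dtype=torch.bfloat16)

    # Quantize with an explicit per-tensor scale (scale=None would derive one from amax).
    qx = QuantizedTensor.from_float(x, TensorCoreFP8Layout,
                                    scale=torch.tensor(0.05), fp8_dtype=torch.float8_e4m3fn)

    x_round_trip = qx.dequantize()   # back to bfloat16, approximately equal to x
    qx_detached = qx.detach()        # handled by the generic detach utility, stays quantized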
@@ -162,74 +237,219 @@ class QuantizedTensorFP8(torch.Tensor):
         new_kwargs = dequant_arg(kwargs)
         return func(*new_args, **new_kwargs)
 
-    def dequantize(self) -> torch.Tensor:
-        plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
-        return plain_tensor * self._scale
-
-    def detach(self):
-        """Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
-        detached_data = self._raw_data.detach()
-        return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)
-
 
 # ==============================================================================
-# Operation Handlers for Quantized Tensors
+# Generic Utilities (Layout-Agnostic Operations)
 # ==============================================================================
 
-@register_quant_op(torch.ops.aten.linear.default)
-def handle_linear_fp8(func, args, kwargs):
-    """
-    Handle F.linear() with quantized inputs.
-
-    Supports:
-    - QuantizedTensorFP8 input + QuantizedTensorFP8 weight
-    - QuantizedTensorFP8 input + regular weight
-    - Regular input + QuantizedTensorFP8 weight
-    """
+def _create_transformed_qtensor(qt, transform_fn):
+    new_data = transform_fn(qt._qdata)
+    new_params = _copy_layout_params(qt._layout_params)
+    return QuantizedTensor(new_data, qt._layout_type, new_params)
+
+
+def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
+    if target_dtype is not None and target_dtype != qt.dtype:
+        logging.warning(
+            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
+            f"but not supported for quantized tensors. Ignoring dtype."
+        )
+
+    if target_layout is not None and target_layout != torch.strided:
+        logging.warning(
+            f"QuantizedTensor: layout change requested to {target_layout}, "
+            f"but not supported. Ignoring layout."
+        )
+
+    # Handle device transfer
+    current_device = qt._qdata.device
+    if target_device is not None:
+        # Normalize device for comparison
+        if isinstance(target_device, str):
+            target_device = torch.device(target_device)
+        if isinstance(current_device, str):
+            current_device = torch.device(current_device)
+
+        if target_device != current_device:
+            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
+            new_q_data = qt._qdata.to(device=target_device)
+            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
+            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
+            return new_qt
+
+    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
+    return qt
+
+
+@register_generic_util(torch.ops.aten.detach.default)
+def generic_detach(func, args, kwargs):
+    """Detach operation - creates a detached copy of the quantized tensor."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _create_transformed_qtensor(qt, lambda x: x.detach())
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.clone.default)
+def generic_clone(func, args, kwargs):
+    """Clone operation - creates a deep copy of the quantized tensor."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _create_transformed_qtensor(qt, lambda x: x.clone())
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._to_copy.default)
+def generic_to_copy(func, args, kwargs):
+    """Device/dtype transfer operation - handles .to(device) calls."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _handle_device_transfer(
+            qt,
+            target_device=kwargs.get('device', None),
+            target_dtype=kwargs.get('dtype', None),
+            op_name="_to_copy"
+        )
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.to.dtype_layout)
+def generic_to_dtype_layout(func, args, kwargs):
+    """Handle .to(device) calls using the dtype_layout variant."""
+    qt = args[0]
+    if isinstance(qt, QuantizedTensor):
+        return _handle_device_transfer(
+            qt,
+            target_device=kwargs.get('device', None),
+            target_dtype=kwargs.get('dtype', None),
+            target_layout=kwargs.get('layout', None),
+            op_name="to"
+        )
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten.copy_.default)
+def generic_copy_(func, args, kwargs):
+    qt_dest = args[0]
+    src = args[1]
+
+    if isinstance(qt_dest, QuantizedTensor):
+        if isinstance(src, QuantizedTensor):
+            # Copy from another quantized tensor
+            qt_dest._qdata.copy_(src._qdata)
+            qt_dest._layout_type = src._layout_type
+            qt_dest._layout_params = _copy_layout_params(src._layout_params)
+        else:
+            # Copy from regular tensor - just copy raw data
+            qt_dest._qdata.copy_(src)
+        return qt_dest
+    return func(*args, **kwargs)
+
+
+@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
+def generic_has_compatible_shallow_copy_type(func, args, kwargs):
+    return True
+
+
+# ==============================================================================
+# FP8 Layout + Operation Handlers
+# ==============================================================================
+class TensorCoreFP8Layout(QuantizedLayout):
+    """
+    Storage format:
+    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
+    - scale: Scalar tensor (float32) for dequantization
+    - orig_dtype: Original dtype before quantization (for casting back)
+    """
+    @classmethod
+    def quantize(cls, tensor, scale=None, fp8_dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+
+        if scale is None:
+            scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale)
+        scale = scale.to(device=tensor.device, dtype=torch.float32)
+
+        if _CK_AVAILABLE and tensor.device.type == "cuda":
+            qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
+        else:
+            lp_amax = torch.finfo(fp8_dtype).max
+            tensor_scaled = tensor.float() / scale
+            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+            qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+
+        layout_params = {
+            'scale': scale,
+            'orig_dtype': orig_dtype
+        }
+        return qdata, layout_params
+
+    @staticmethod
+    def dequantize(qdata, scale, orig_dtype, **kwargs):
+        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
+        return plain_tensor * scale
+
+    @classmethod
+    def get_plain_tensors(cls, qtensor):
+        return qtensor._qdata, qtensor._layout_params['scale']
+
+
+@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
+def fp8_linear(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
     bias = args[2] if len(args) > 2 else None
-    out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype)
 
-    # Case 1: Both input and weight are FP8
-    if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
-        # Get plain tensors to avoid dispatch recursion
-        plain_input = input_tensor._raw_data
-        plain_weight = weight._raw_data
-        weight_t = plain_weight.t()  # Keep as column-major for cuBLASLt
+    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
+        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
+        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
+
+        out_dtype = kwargs.get("out_dtype")
+        if out_dtype is None:
+            out_dtype = input_tensor._layout_params['orig_dtype']
+
+        weight_t = plain_weight.t()
+
+        tensor_2d = False
+        if len(plain_input.shape) == 2:
+            tensor_2d = True
+            plain_input = plain_input.unsqueeze(1)
+
+        input_shape = plain_input.shape
+        if len(input_shape) != 3:
+            return None
+
         try:
             output = torch._scaled_mm(
-                plain_input,
+                plain_input.reshape(-1, input_shape[2]),
                 weight_t,
                 bias=bias,
-                scale_a=input_tensor._scale,
-                scale_b=weight._scale,
+                scale_a=scale_a,
+                scale_b=scale_b,
                 out_dtype=out_dtype,
             )
-            if isinstance(output, tuple):
-                output = output[0]
+            if not tensor_2d:
+                output = output.reshape((-1, input_shape[1], weight.shape[0]))
 
             if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                output_scale = input_tensor._scale * weight._scale
-                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)  # TODO is this correct? Can't cuBLAS return it
+                output_scale = scale_a * scale_b
+                output_params = {
+                    'scale': output_scale,
+                    'orig_dtype': input_tensor._layout_params['orig_dtype']
+                }
+                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
             else:
                 return output
 
         except Exception as e:
-            logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
+            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
 
-    # Case 2: Only weight is quantized
-    if isinstance(weight, QuantizedTensorFP8):
-        weight_dq = weight.dequantize()
-        input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor
-        return torch.nn.functional.linear(input_dq, weight_dq, bias)
-
-    # Case 3: Only input is quantized
-    elif isinstance(input_tensor, QuantizedTensorFP8):
-        input_dq = input_tensor.dequantize()
-        return torch.nn.functional.linear(input_dq, weight, bias)
-
-    # Case 4: Neither is quantized (shouldn't happen, but handle it)
-    else:
-        return torch.nn.functional.linear(input_tensor, weight, bias)
+    # Case 2: DQ Fallback
+    if isinstance(weight, QuantizedTensor):
+        weight = weight.dequantize()
+    if isinstance(input_tensor, QuantizedTensor):
+        input_tensor = input_tensor.dequantize()
+
+    return torch.nn.functional.linear(input_tensor, weight, bias)
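And a layout-level sketch (not part of the diff) showing the `quantize`/`dequantize` pair directly, which is all `QuantizedTensor.from_float` and `.dequantize()` delegate to; the CPU fallback path (no comfy_kitchen, no CUDA) and the module path are assumptions.

    import torch
    from comfy.quant_ops import TensorCoreFP8Layout  # assumed module path

    w = torch.randn(64, 64)

    # quantize() returns the raw FP8 payload plus the params needed to invert it
    qdata, params = TensorCoreFP8Layout.quantize(w, fp8_dtype=torch.float8_e4m3fn)
    print(qdata.dtype)                 # torch.float8_e4m3fn
    print(sorted(params))              # ['orig_dtype', 'scale']

    w_approx = TensorCoreFP8Layout.dequantize(qdata, **params)
    print((w - w_approx).abs().max())  # small per-tensor quantization error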
@@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_diffusion_model_state_dict(sd, model_options={}):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
 
@@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     weight_dtype = comfy.utils.weight_dtype(sd)
 
     load_device = model_management.get_torch_device()
-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
 
     if model_config is not None:
         new_sd = sd
@@ -1346,8 +1346,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
 
 
 def load_diffusion_model(unet_path, model_options={}):
-    sd = comfy.utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options)
+    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
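For completeness, a hedged sketch (not part of the diff) of how the new metadata plumbing is exercised end to end; the checkpoint path is a placeholder.

    import comfy.utils
    import comfy.sd

    unet_path = "models/diffusion_models/example_fp8_mixed.safetensors"  # placeholder path
    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)

    # metadata (including any _quantization_metadata entry) now reaches
    # model_detection.model_config_from_unet via load_diffusion_model_state_dict.
    model = comfy.sd.load_diffusion_model_state_dict(sd, metadata=metadata)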