diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7a3851228..feab164c6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -7,121 +7,24 @@ import logging import torch -# ============================================================================== -# Quantization Detection Functions -# ============================================================================== - -def normalize_layer_name(full_key, known_prefixes): - """ - Strip model prefix and parameter suffix from a state dict key. - - Args: - full_key: Full state dict key (e.g., "model.diffusion_model.layer1.weight") - known_prefixes: List of known model prefixes to strip - - Returns: - Normalized layer name (e.g., "layer1") - """ - name = full_key - - # Strip model prefix - for prefix in known_prefixes: - if name.startswith(prefix): - name = name[len(prefix):] - break - - # Remove parameter suffix - for suffix in [".weight", ".bias", ".scale_weight", ".scale_input"]: - if name.endswith(suffix): - name = name[:-len(suffix)] - break - - return name - - -def detect_layer_quantization(state_dict, prefix="model.diffusion_model."): - """ - Detect per-layer quantization configuration from state dict. - - Detection priority: - 1. Check for _quantization_metadata key (new format) - 2. Check for scaled_fp8 key (legacy format - return None) - 3. Check for per-layer scale_weight patterns (mixed detection) - 4. No quantization detected (return None) - - Args: - state_dict: Model state dictionary - prefix: Key prefix for model layers - - Returns: - Dict mapping layer names to quantization configs, or None for legacy/no quantization. - - Example return value: - { - "input_blocks.5.1.transformer_blocks.0.attn1.to_q": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "middle_block.1.transformer_blocks.0.attn2.to_k": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - """ - - # 1. Check for new metadata format - metadata_key = f"{prefix}_quantization_metadata" - if metadata_key in state_dict: - try: - metadata = state_dict.pop(metadata_key) - if isinstance(metadata, dict) and "layers" in metadata: - logging.info(f"Found quantization metadata (version {metadata.get('format_version', 'unknown')})") - return metadata["layers"] - else: - logging.warning(f"Invalid quantization metadata format, ignoring") - except Exception as e: - logging.error(f"Failed to parse quantization metadata: {e}") - return None +def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."): + # 1. Check for per-layer config in metadata + quant_key = "_quantization_metadata" + if metadata is not None and quant_key in metadata: + quant_metadata = metadata.pop(quant_key) + quant_metadata = json.loads(quant_metadata) + if isinstance(quant_metadata, dict) and "layers" in quant_metadata: + logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") + return quant_metadata["layers"] + else: + raise ValueError(f"Invalid quantization metadata format") # 2. Check for legacy scaled_fp8 marker - # If present, return None to use legacy code path scaled_fp8_key = f"{prefix}scaled_fp8" if scaled_fp8_key in state_dict: logging.debug("Detected legacy scaled_fp8 format, using legacy code path") return None - - # 3. 
Check for per-layer scale patterns (mixed precision without metadata) - # Look for layers that have scale_weight but not all layers have it - known_prefixes = [prefix] - layer_configs = {} - layers_with_scale = set() - layers_with_weight = set() - - for key in state_dict.keys(): - if key.startswith(prefix): - if key.endswith(".scale_weight"): - layer_name = normalize_layer_name(key, known_prefixes) - layers_with_scale.add(layer_name) - # Detect format based on weight dtype - weight_key = f"{prefix}{layer_name}.weight" - if weight_key in state_dict: - weight_dtype = state_dict[weight_key].dtype - if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - format_name = "fp8_e4m3fn" if weight_dtype == torch.float8_e4m3fn else "fp8_e5m2" - layer_configs[layer_name] = { - "format": format_name, - "params": {} - } - elif key.endswith(".weight") and not key.endswith(".scale_weight"): - layer_name = normalize_layer_name(key, known_prefixes) - layers_with_weight.add(layer_name) - - # If we found scale_weight on some but not all layers, it's mixed precision - if layer_configs and len(layers_with_scale) < len(layers_with_weight): - logging.info(f"Detected mixed precision via scale patterns: {len(layers_with_scale)} quantized layers, {len(layers_with_weight)} total layers") - return layer_configs - - # 4. No quantization detected + return None @@ -821,7 +724,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal model_config.optimizations["fp8"] = True # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(state_dict, unet_key_prefix) + layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix) if layer_quant_config: model_config.layer_quant_config = layer_quant_config logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") diff --git a/comfy/ops.py b/comfy/ops.py index 3e4588706..4e24b25de 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -345,19 +345,8 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): """ - Legacy FP8 linear function - now uses tensor subclass infrastructure. - - This function maintains backward compatibility with existing code while - routing all FP8 computation through the unified tensor subclass system. - All actual FP8 matmul logic is handled by the registered operation handlers - in quant_ops.py via __torch_dispatch__. - - Args: - self: Linear layer with FP8 weight and scale parameters - input: Input tensor (any dtype) - - Returns: - Output tensor or None if weight is not FP8 + Legacy FP8 linear function for backward compatibility. + Uses QuantizedTensor subclass for dispatch. 
""" dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: @@ -372,10 +361,8 @@ def fp8_linear(self, input): input_dtype = input.dtype if len(input.shape) == 3: - # Get weight and bias using standard casting w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - # Get scales (same as before) scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: @@ -388,14 +375,13 @@ def fp8_linear(self, input): else: scale_input = scale_input.to(input.device) - # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch - quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype) - quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype) - # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py! - # This is the key unification: all FP8 computation goes through one path + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - # Reshape output if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -492,183 +478,117 @@ if CUBLAS_IS_AVAILABLE: return super().forward(*args, **kwargs) -# Import quantization operations from separate module -from .quant_ops import QuantizedTensorFP8 - - # ============================================================================== # Mixed Precision Operations # ============================================================================== +from .quant_ops import QuantizedTensor, TensorCoreFP8Layout + +QUANT_FORMAT_MIXINS = { + "float8_e4m3fn": { + "dtype": torch.float8_e4m3fn, + "layout_type": TensorCoreFP8Layout, + "parameters": { + "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), + } + } +} class MixedPrecisionOps(disable_weight_init): - """ - Operations class supporting per-layer quantization (mixed precision). - - This class enables different layers to use different quantization formats - within the same model (e.g., some layers FP8, others BF16). - - Layer-specific quantization is configured via _layer_quant_config class variable, - which is set by pick_operations() when a model has mixed precision. 
- """ - - _layer_quant_config = {} # Class variable set by pick_operations() - - class Linear(disable_weight_init.Linear): - """Linear layer with optional per-layer quantization using tensor subclasses""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.quant_format = None - self.quant_scale = None - self._quantization_initialized = False - + _layer_quant_config = {} + _compute_dtype = torch.bfloat16 + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + def reset_parameters(self): - # Don't allocate weights - return None like disable_weight_init return None def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - """ - Called by PyTorch during load_state_dict. - Load weight and wrap in QuantizedTensorFP8 if this layer is quantized. - """ - # Call parent to load weight and bias first - super()._load_from_state_dict( - state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs - ) - - # After weight is loaded, wrap it if this layer is quantized - if not self._quantization_initialized: - # Normalize layer name from prefix - layer_name = prefix.rstrip('.') - - # Strip known model prefixes - for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: - if layer_name.startswith(model_prefix): - layer_name = layer_name[len(model_prefix):] - break - - # Check if this layer has quantization config - if layer_name in MixedPrecisionOps._layer_quant_config: - config = MixedPrecisionOps._layer_quant_config[layer_name] - self.quant_format = config.get("format", "fp8_e4m3fn") - - # Load scale parameter - scale_key = f"{prefix}scale_weight" - if scale_key in state_dict: - self.quant_scale = state_dict[scale_key] - - # Wrap weight in QuantizedTensorFP8 - if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - try: - # Determine original dtype (default to bfloat16) - orig_dtype = torch.bfloat16 - - # Wrap weight in quantized tensor subclass - quantized_weight = QuantizedTensorFP8( - self.weight.data, - self.quant_scale, - orig_dtype=orig_dtype - ) - - # Replace weight parameter with wrapped version - self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) - - logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})") - except Exception as e: - logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization") - self.quant_format = None - self.quant_scale = None - else: - logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict") - self.quant_format = None - - self._quantization_initialized = True - - def _save_to_state_dict(self, destination, prefix, keep_vars): - """Save layer parameters including quantization scale""" - # First unwrap the 
weight if it's quantized
-            if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8):
-                # Temporarily unwrap to save the raw FP8 data
-                quantized_tensor = self.weight.data
-                raw_fp8_data = quantized_tensor._raw_data
-                original_weight = self.weight
-                self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False)
-
-                # Call parent to save unwrapped weight
-                super()._save_to_state_dict(destination, prefix, keep_vars)
-
-                # Restore the wrapped weight
-                self.weight = original_weight
-
-                # Save the scale parameter
-                if self.quant_scale is not None:
-                    destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach()
-            else:
-                # Standard path for non-quantized weights
-                super()._save_to_state_dict(destination, prefix, keep_vars)
-
-        def forward_comfy_cast_weights(self, input):
-            """
-            Forward pass - tensor subclass handles dispatch automatically!
-            __torch_dispatch__ will route to registered handlers based on tensor types.
-            """
-            weight, bias = cast_bias_weight(self, input)
-
-            # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it!
-            return torch.nn.functional.linear(input, weight, bias)
-
-        def forward(self, *args, **kwargs):
-            """Main forward pass"""
-            run_every_op()
-            # Same logic as disable_weight_init.Linear
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
-                return self.forward_comfy_cast_weights(*args, **kwargs)
-            else:
-                return super().forward(*args, **kwargs)
-
-    @classmethod
-    def conv_nd(s, dims, *args, **kwargs):
-        """Create Conv layer (same as disable_weight_init)"""
-        if dims == 2:
-            return s.Conv2d(*args, **kwargs)
-        elif dims == 3:
-            return s.Conv3d(*args, **kwargs)
-        else:
-            raise ValueError(f"unsupported dimensions: {dims}")
+            device = self.factory_kwargs["device"]
+            layer_name = prefix.rstrip('.')
+            weight_key = f"{prefix}weight"
+            weight = state_dict.pop(weight_key, None)
+            if weight is None:
+                raise ValueError(f"Missing weight for layer {layer_name}")
+
+            if layer_name not in MixedPrecisionOps._layer_quant_config:
+                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            else:
+                quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                if quant_format is None:
+                    raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                mixin = QUANT_FORMAT_MIXINS[quant_format]
+                self.layout_type = mixin["layout_type"]
+
+                layout_params = {
+                    'scale': state_dict.pop(f"{prefix}weight_scale", None),
+                    'orig_dtype': MixedPrecisionOps._compute_dtype
+                }
+                self.weight = torch.nn.Parameter(
+                    QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
+                    requires_grad=False
+                )
+
+                for param_name, param_value in mixin["parameters"].items():
+                    _v = state_dict.pop(f"{prefix}{param_name}", None)
+                    if _v is None:
+                        continue
+                    setattr(self, param_name, _v.to(device=device))
+
+            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        def _forward(self, input, weight, bias):
+            return torch.nn.functional.linear(input, weight, bias)
+
+        def forward_comfy_cast_weights(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._forward(input, weight, bias)
+
+        def forward(self, input, *args, **kwargs):
+            run_every_op()
+
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return 
self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + getattr(self, 'input_scale', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): - """ - Select appropriate operations class for model. - - NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). - LEGACY: All other paths unchanged for backward compatibility. - - Args: - weight_dtype: Weight storage dtype - compute_dtype: Computation dtype - load_device: Device for loading - disable_fast_fp8: Disable fast FP8 paths - fp8_optimizations: Enable FP8 optimizations - scaled_fp8: Legacy FP8 dtype marker - model_config: Model config object (optional, for mixed precision support) - - Returns: - Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) - """ - # NEW: Check for mixed precision + # If model_config.layer_quant_config exists, use new MixedPrecisionOps. if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config + # MixedPrecisionOps._compute_dtype = compute_dtype # TODO + MixedPrecisionOps._compute_dtype = torch.bfloat16 logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8e3bacbaf..3802da852 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,42 +1,79 @@ import torch import logging +from typing import Tuple, Dict -# ============================================================================== -# Global Operation Registry -# ============================================================================== +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} -# Global operation registry: torch operation → handler function -_QUANT_OP_REGISTRY = {} -def register_quant_op(torch_op): +def register_layout_op(torch_op, layout_type): """ - Decorator to register an operation handler. - + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) Example: - @register_quant_op(torch.ops.aten.linear.default) - def handle_linear_fp8(func, args, kwargs): - # Implementation + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation ... 
""" def decorator(handler_func): - _QUANT_OP_REGISTRY[torch_op] = handler_func + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func return handler_func return decorator -def get_quant_handler(torch_op): - """Get registered handler for an operation""" - return _QUANT_OP_REGISTRY.get(torch_op) +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) + + Example: + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout + ... + """ + def decorator(handler_func): + _GENERIC_UTILS[torch_op] = handler_func + return handler_func + return decorator -def list_registered_ops(): - """List all registered quantized operations""" - return list(_QUANT_OP_REGISTRY.keys()) +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None -# ============================================================================== -# comfy_kitchen Integration -# ============================================================================== +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params + + +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + try: import comfy_kitchen as ck @@ -53,106 +90,144 @@ except Exception as e: logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks") -# ============================================================================== -# Quantized Tensor Subclass -# ============================================================================== - -class QuantizedTensorFP8(torch.Tensor): +class QuantizedLayout: """ - Tensor subclass for FP8 quantized data. - Automatically handles operations via __torch_dispatch__. + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") + + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError(f"TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): + """ + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. 
+ + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ @staticmethod - def __new__(cls, tensor, scale, orig_dtype=torch.bfloat16): + def __new__(cls, qdata, layout_type, layout_params): """ - Create a quantized FP8 tensor. + Create a quantized tensor. Args: - tensor: The FP8 tensor data (torch.float8_e4m3fn or e5m2) - scale: Scale factor for dequantization (scalar tensor) - orig_dtype: Original dtype before quantization + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_subclass(cls, tensor, require_grad=False) + return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - def __init__(self, tensor, scale, orig_dtype=torch.bfloat16): - self._scale = scale - self._orig_dtype = orig_dtype - # Store a reference to prevent infinite recursion in dequantize - self._raw_data = tensor.contiguous() + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata.contiguous() + self._layout_type = layout_type + self._layout_params = layout_params def __repr__(self): - return (f"QuantizedTensorFP8(shape={self.shape}, " - f"scale={self._scale:.4f}, dtype={self._orig_dtype})") + layout_name = self._layout_type.__name__ + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" + + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_q_data"] + ctx = { + "layout_type": self._layout_type, + } + + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v + + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params + + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. 
+ """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] + + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) @classmethod - def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn): - orig_dtype = tensor.dtype - - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) - - tensor_fp8 = None - if _CK_AVAILABLE: - try: - tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) - except Exception as e: - logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}") - - if tensor_fp8 is None: - lp_amax = torch.finfo(fp8_dtype).max - tensor_scaled = tensor.float() / scale - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) - tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) - - return cls(tensor_fp8, scale, orig_dtype=orig_dtype) + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) - @classmethod - def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn): - if strategy == "amax": - scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max - scale = scale.to(tensor.device, dtype=torch.float32) - else: - raise ValueError(f"Unknown quantization strategy: {strategy}. " - f"Supported: 'amax'") - - return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype) + def dequantize(self) -> torch.Tensor: + return self._layout_type.dequantize(self._qdata, **self._layout_params) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - # Special case: skip dispatch for internal tensor operations - # that are used for unwrapping (to avoid recursion) - if func in [torch.ops.aten._to_copy.default, torch.ops.aten.detach.default]: - # For these ops, use the raw data to avoid recursion, but return QuantizedTensorFP8 for detach - if func == torch.ops.aten.detach.default and isinstance(args[0], QuantizedTensorFP8): - # Special handling for detach - return a new QuantizedTensorFP8 - qt = args[0] - detached_data = qt._raw_data.detach() - return QuantizedTensorFP8(detached_data, qt._scale, qt._orig_dtype) - - # For other ops, just unwrap - def unwrap(arg): - if isinstance(arg, QuantizedTensorFP8): - return arg._raw_data - return arg - new_args = tuple(unwrap(a) if not isinstance(a, (list, tuple, dict)) else a for a in args) - return func(*new_args, **kwargs) + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) - # Look up registered handler for this operation - handler = _QUANT_OP_REGISTRY.get(func) - if handler: - return handler(func, args, kwargs) + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) - # No handler - dequantize and use standard path + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. 
kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod def _dequant_and_fallback(cls, func, args, kwargs): - """Fallback: dequantize all quantized tensors""" def dequant_arg(arg): - if isinstance(arg, QuantizedTensorFP8): + if isinstance(arg, QuantizedTensor): return arg.dequantize() elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) @@ -161,75 +236,220 @@ class QuantizedTensorFP8(torch.Tensor): new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) - - def dequantize(self) -> torch.Tensor: - plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype) - return plain_tensor * self._scale - - def detach(self): - """Detach returns a new QuantizedTensorFP8 (required for Parameter)""" - detached_data = self._raw_data.detach() - return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype) # ============================================================================== -# Operation Handlers for Quantized Tensors +# Generic Utilities (Layout-Agnostic Operations) # ============================================================================== -@register_quant_op(torch.ops.aten.linear.default) -def handle_linear_fp8(func, args, kwargs): - """ - Handle F.linear() with quantized inputs. +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_dtype is not None and target_dtype != qt.dtype: + logging.warning( + f"QuantizedTensor: dtype conversion requested to {target_dtype}, " + f"but not supported for quantized tensors. Ignoring dtype." + ) - Supports: - - QuantizedTensorFP8 input + QuantizedTensorFP8 weight - - QuantizedTensorFP8 input + regular weight - - Regular input + QuantizedTensorFP8 weight + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." 
+ ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt + + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata) + qt_dest._layout_type = src._layout_type + qt_dest._layout_params = _copy_layout_params(src._layout_params) + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True + +# ============================================================================== +# FP8 Layout + Operation Handlers +# ============================================================================== +class TensorCoreFP8Layout(QuantizedLayout): """ + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) + """ + @classmethod + def quantize(cls, tensor, 
scale=None, fp8_dtype=torch.float8_e4m3fn): + orig_dtype = tensor.dtype + + if scale is None: + scale = torch.amax(tensor.abs()) / torch.finfo(fp8_dtype).max + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + if _CK_AVAILABLE and tensor.device.type == "cuda": + qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype) + else: + lp_amax = torch.finfo(fp8_dtype).max + tensor_scaled = tensor.float() / scale + torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return qdata, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + return plain_tensor * scale + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + + +@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype) - - # Case 1: Both input and weight are FP8 - if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8): - # Get plain tensors to avoid dispatch recursion - plain_input = input_tensor._raw_data - plain_weight = weight._raw_data - weight_t = plain_weight.t() # Keep as column-major for cuBLASLt + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + try: output = torch._scaled_mm( - plain_input, + plain_input.reshape(-1, input_shape[2]), weight_t, bias=bias, - scale_a=input_tensor._scale, - scale_b=weight._scale, + scale_a=scale_a, + scale_b=scale_b, out_dtype=out_dtype, ) - if isinstance(output, tuple): - output = output[0] - + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = input_tensor._scale * weight._scale - return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype) # TODO is this correct? 
Can't cuBLAS return it
+                output_scale = scale_a * scale_b
+                output_params = {
+                    'scale': output_scale,
+                    'orig_dtype': input_tensor._layout_params['orig_dtype']
+                }
+                return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
             else:
                 return output
+
         except Exception as e:
-            logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
+            raise RuntimeError(f"FP8 _scaled_mm failed: {e}") from e

-    # Case 2: Only weight is quantized
-    if isinstance(weight, QuantizedTensorFP8):
-        weight_dq = weight.dequantize()
-        input_dq = input_tensor.dequantize() if isinstance(input_tensor, QuantizedTensorFP8) else input_tensor
-        return torch.nn.functional.linear(input_dq, weight_dq, bias)
-
-    # Case 3: Only input is quantized
-    elif isinstance(input_tensor, QuantizedTensorFP8):
-        input_dq = input_tensor.dequantize()
-        return torch.nn.functional.linear(input_dq, weight, bias)
-
-    # Case 4: Neither is quantized (shouldn't happen, but handle it)
-    else:
-        return torch.nn.functional.linear(input_tensor, weight, bias)
+    # Case 2: dequantize fallback (input and weight are not both FP8 QuantizedTensors)
+    if isinstance(weight, QuantizedTensor):
+        weight = weight.dequantize()
+    if isinstance(input_tensor, QuantizedTensor):
+        input_tensor = input_tensor.dequantize()
+    return torch.nn.functional.linear(input_tensor, weight, bias)
diff --git a/comfy/sd.py b/comfy/sd.py
index 28bee248d..b965e9842 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1262,7 +1262,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     return (model_patcher, clip, vae, clipvision)


-def load_diffusion_model_state_dict(sd, model_options={}):
+def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     """
     Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.

@@ -1296,7 +1296,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
     weight_dtype = comfy.utils.weight_dtype(sd)
     load_device = model_management.get_torch_device()

-    model_config = model_detection.model_config_from_unet(sd, "")
+    model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)

     if model_config is not None:
         new_sd = sd
@@ -1346,8 +1346,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):


 def load_diffusion_model(unet_path, model_options={}):
-    sd = comfy.utils.load_torch_file(unet_path)
-    model = load_diffusion_model_state_dict(sd, model_options=model_options)
+    sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
     if model is None:
         logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
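
For a checkpoint that should exercise the new detection path, here is a minimal sketch of the _quantization_metadata payload that detect_layer_quantization() accepts. The layer names and version string are illustrative; "float8_e4m3fn" is the only format key registered in QUANT_FORMAT_MIXINS by this patch, and quantized layers are additionally expected to carry <layer>.weight_scale (and optionally <layer>.input_scale) tensors in the state dict.

import json

# Illustrative layer names. The value stored in the safetensors header must be a
# JSON string, since detect_layer_quantization() json.loads() it before reading
# "format_version" and "layers".
quant_metadata = {
    "format_version": "1.0",
    "layers": {
        "input_blocks.5.1.transformer_blocks.0.attn1.to_q": {"format": "float8_e4m3fn"},
        "middle_block.1.transformer_blocks.0.attn2.to_k": {"format": "float8_e4m3fn"},
    },
}
metadata = {"_quantization_metadata": json.dumps(quant_metadata)}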
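At the API level the visible change is that the header metadata now has to reach model_config_from_unet(); load_diffusion_model() does this automatically, while direct callers of load_diffusion_model_state_dict() pass it explicitly. A short sketch, with a placeholder checkpoint path:

import comfy.sd
import comfy.utils

# return_metadata=True yields (state_dict, header_metadata) for safetensors files.
sd, metadata = comfy.utils.load_torch_file("unet_mixed_fp8.safetensors", return_metadata=True)
model_patcher = comfy.sd.load_diffusion_model_state_dict(sd, metadata=metadata)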
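A rough round-trip sketch of the new tensor subclass outside of any Linear layer; shapes, dtypes and device are arbitrary. On a device without FP8 matmul, the linear call should land in the dequantize fallback of the registered handler rather than torch._scaled_mm.

import torch
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

w = torch.randn(128, 64, dtype=torch.bfloat16)
x = torch.randn(4, 64, dtype=torch.bfloat16)

# quantize() computes an amax-based scale when none is passed and returns
# (fp8 qdata, {"scale": ..., "orig_dtype": ...}); from_float() wraps that in the subclass.
qw = QuantizedTensor.from_float(w, TensorCoreFP8Layout, fp8_dtype=torch.float8_e4m3fn)

# __torch_dispatch__ routes aten.linear to the TensorCoreFP8Layout handler; with a
# plain bf16 input this takes the dequantize path (Case 2) and still returns bf16.
y = torch.nn.functional.linear(x, qw)

# Explicit dequantization recovers an approximation of the original weight.
w_approx = qw.dequantize()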
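To illustrate the point of splitting _LAYOUT_REGISTRY from _GENERIC_UTILS, here is a hypothetical second layout wired in through the same decorators. Int8Layout and its handler are not part of this patch; the handler only implements the dequantize path, and get_plain_tensors() is omitted because nothing below calls it.

import torch
from comfy.quant_ops import QuantizedLayout, QuantizedTensor, register_layout_op


class Int8Layout(QuantizedLayout):
    # Hypothetical per-tensor symmetric int8 layout, for illustration only.
    @classmethod
    def quantize(cls, tensor, **kwargs):
        scale = tensor.abs().amax().float() / 127.0
        qdata = torch.clamp(tensor.float() / scale, -127, 127).to(torch.int8)
        return qdata, {"scale": scale, "orig_dtype": tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return (qdata.float() * scale).to(orig_dtype)


@register_layout_op(torch.ops.aten.linear.default, Int8Layout)
def int8_linear(func, args, kwargs):
    # Simplest possible handler: dequantize and run the standard kernel.
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight, bias)


# Generic utilities (detach, clone, .to(device)) already work for any layout,
# so only the compute ops need per-layout handlers.
qw = QuantizedTensor.from_float(torch.randn(16, 8), Int8Layout)
out = torch.nn.functional.linear(torch.randn(2, 8), qw)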