""" This file is part of ComfyUI. Copyright (C) 2024 Stability AI This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ import torch import logging import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib def run_every_op(): comfy.model_management.throw_exception_if_processing_interrupted() def scaled_dot_product_attention(q, k, v, *args, **kwargs): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) try: if torch.cuda.is_available(): from torch.nn.attention import SDPBackend, sdpa_kernel import inspect if "set_priority" in inspect.signature(sdpa_kernel).parameters: SDPA_BACKEND_PRIORITY = [ SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, ] SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) def scaled_dot_product_attention(q, k, v, *args, **kwargs): with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) else: logging.warning("Torch version too old to set sdpa backend priority.") except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") cast_to = comfy.model_management.cast_to #TODO: remove once no more references if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: torch.backends.cudnn.benchmark = True def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: if dtype is None: dtype = input.dtype if bias_dtype is None: bias_dtype = dtype if device is None: device = input.device offload_stream = comfy.model_management.get_offload_stream(device) if offload_stream is not None: wf_context = offload_stream else: wf_context = contextlib.nullcontext() bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = len(s.bias_function) > 0 bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: with wf_context: for f in s.bias_function: bias = f(bias) has_function = len(s.weight_function) > 0 weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: with wf_context: for f in s.weight_function: weight = f(weight) comfy.model_management.sync_stream(device, offload_stream) return weight, bias class CastWeightBiasOp: comfy_cast_weights = False weight_function = [] bias_function = [] class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class Conv1d(torch.nn.Conv1d, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class Conv2d(torch.nn.Conv2d, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class Conv3d(torch.nn.Conv3d, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): if self.weight is not None: weight, bias = cast_bias_weight(self, input) else: weight = None bias = None return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp): def reset_parameters(self): self.bias = None return None def forward_comfy_cast_weights(self, input): if self.weight is not None: weight, bias = cast_bias_weight(self, input) else: weight = None return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input, output_size=None): num_spatial_dims = 2 output_padding = self._output_padding( input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) weight, bias = cast_bias_weight(self, input) return torch.nn.functional.conv_transpose2d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input, output_size=None): num_spatial_dims = 1 output_padding = self._output_padding( input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) weight, bias = cast_bias_weight(self, input) return torch.nn.functional.conv_transpose1d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) class Embedding(torch.nn.Embedding, CastWeightBiasOp): def reset_parameters(self): self.bias = None return None def forward_comfy_cast_weights(self, input, out_dtype=None): output_dtype = out_dtype if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: out_dtype = None weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) def forward(self, *args, **kwargs): run_every_op() if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: if "out_dtype" in kwargs: kwargs.pop("out_dtype") return super().forward(*args, **kwargs) @classmethod def conv_nd(s, dims, *args, **kwargs): if dims == 2: return s.Conv2d(*args, **kwargs) elif dims == 3: return s.Conv3d(*args, **kwargs) else: raise ValueError(f"unsupported dimensions: {dims}") class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): comfy_cast_weights = True class Conv1d(disable_weight_init.Conv1d): comfy_cast_weights = True class Conv2d(disable_weight_init.Conv2d): comfy_cast_weights = True class Conv3d(disable_weight_init.Conv3d): comfy_cast_weights = True class GroupNorm(disable_weight_init.GroupNorm): comfy_cast_weights = True class LayerNorm(disable_weight_init.LayerNorm): comfy_cast_weights = True class ConvTranspose2d(disable_weight_init.ConvTranspose2d): comfy_cast_weights = True class ConvTranspose1d(disable_weight_init.ConvTranspose1d): comfy_cast_weights = True class RMSNorm(disable_weight_init.RMSNorm): comfy_cast_weights = True class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True def fp8_linear(self, input): dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None tensor_2d = False if len(input.shape) == 2: tensor_2d = True input = input.unsqueeze(1) input_shape = input.shape input_dtype = input.dtype if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) w = w.t() scale_weight = self.scale_weight scale_input = self.scale_input if scale_weight is None: scale_weight = torch.ones((), device=input.device, dtype=torch.float32) else: scale_weight = scale_weight.to(input.device) if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) input = torch.clamp(input, min=-448, max=448, out=input) input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() else: scale_input = scale_input.to(input.device) input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() if bias is not None: o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) else: o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) if isinstance(o, tuple): o = o[0] if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None class fp8_ops(manual_cast): class Linear(manual_cast.Linear): def reset_parameters(self): self.scale_weight = None self.scale_input = None return None def forward_comfy_cast_weights(self, input): if not self.training: try: out = fp8_linear(self, input) if out is not None: return out except Exception as e: logging.info("Exception during fp8 op: {}".format(e)) weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) class scaled_fp8_op(manual_cast): class Linear(manual_cast.Linear): def __init__(self, *args, **kwargs): if override_dtype is not None: kwargs['dtype'] = override_dtype super().__init__(*args, **kwargs) def reset_parameters(self): if not hasattr(self, 'scale_weight'): self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) if not scale_input: self.scale_input = None if not hasattr(self, 'scale_input'): self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) return None def forward_comfy_cast_weights(self, input): if fp8_matrix_mult: out = fp8_linear(self, input) if out is not None: return out weight, bias = cast_bias_weight(self, input) if weight.numel() < input.numel(): #TODO: optimize return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) else: return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) def convert_weight(self, weight, inplace=False, **kwargs): if inplace: weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) return weight else: return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) if return_weight: return weight if inplace_update: self.weight.data.copy_(weight) else: self.weight = torch.nn.Parameter(weight, requires_grad=False) return scaled_fp8_op CUBLAS_IS_AVAILABLE = False try: from cublas_ops import CublasLinear CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: class cublas_ops(disable_weight_init): class Linear(CublasLinear, disable_weight_init.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): return super().forward(input) def forward(self, *args, **kwargs): return super().forward(*args, **kwargs) # ============================================================================== # Quantization Format Registry System # ============================================================================== class QuantFormatHandler: """ Base class for all quantization format handlers. A handler encapsulates the logic for a specific quantization format (e.g., FP8 scaled, MX formats) and manages the quantization parameters and forward pass for quantized layers. """ def __init__(self, layer, **config): """ Initialize handler for a specific layer. Args: layer: The nn.Module layer (Linear, Conv2d, etc.) **config: Format-specific configuration """ self.layer = layer self.config = config def setup_parameters(self): """ Initialize quantization parameters on the layer. Called during layer construction or load_state_dict. Subclasses should create parameters like scale_weight, scale_input, etc. and attach them to self.layer. """ raise NotImplementedError(f"{self.__class__.__name__} must implement setup_parameters()") def forward(self, *args, **kwargs): """ Execute quantized forward pass. Signature matches the layer's expected forward pass. Handler accesses layer parameters via self.layer (weight, bias, etc.) Args: *args: Positional arguments matching layer forward signature **kwargs: Keyword arguments matching layer forward signature Returns: Layer output tensor Examples: Linear: forward(input) Conv2d: forward(input) GroupNorm: forward(input) """ raise NotImplementedError(f"{self.__class__.__name__} must implement forward()") def load_state_dict(self, state_dict, prefix): """ Load quantization parameters from state dict. Args: state_dict: State dictionary prefix: Key prefix for this layer (e.g., "model.diffusion_model.layer1.") """ raise NotImplementedError(f"{self.__class__.__name__} must implement load_state_dict()") def state_dict(self, prefix): """ Save quantization parameters to state dict. Args: prefix: Key prefix for this layer Returns: Dictionary of quantization parameters with full keys """ raise NotImplementedError(f"{self.__class__.__name__} must implement state_dict()") def convert_weight(self, weight, inplace=False): """ Convert weight from quantized to full precision (dequantize). Args: weight: Quantized weight tensor inplace: Whether to modify in-place Returns: Dequantized weight tensor """ raise NotImplementedError(f"{self.__class__.__name__} must implement convert_weight()") def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): """ Convert and set weight from full precision to quantized. Args: weight: Full precision weight tensor inplace_update: Whether to update layer weight in-place seed: Random seed for stochastic rounding return_weight: If True, return quantized weight without setting Returns: Quantized weight if return_weight=True, else None """ raise NotImplementedError(f"{self.__class__.__name__} must implement set_weight()") class QuantFormatRegistry: """ Global registry for quantization formats. Formats are registered with a unique name and handler class. Custom formats can be registered by custom nodes. """ _formats = {} @classmethod def register(cls, name, handler_class, **default_config): """ Register a new quantization format. Args: name: Unique format identifier (e.g., "fp8_e4m3fn_scaled") handler_class: Handler class implementing QuantFormatHandler **default_config: Default configuration parameters Example: QuantFormatRegistry.register( "fp8_e4m3fn_scaled", handler_class=FP8ScaledHandler, base_dtype=torch.float8_e4m3fn, quantize_activation=False, use_fp8_matmul=True, ) """ if not issubclass(handler_class, QuantFormatHandler): raise TypeError(f"handler_class must be a subclass of QuantFormatHandler, got {handler_class}") cls._formats[name] = { "handler": handler_class, "config": default_config.copy() } logging.debug(f"Registered quantization format: {name}") @classmethod def get(cls, name, **override_config): """ Get format info with optional config overrides. Args: name: Format identifier **override_config: Configuration overrides Returns: Dict with 'handler' (class) and 'config' (dict) keys Raises: ValueError: If format name not registered """ if name not in cls._formats: available = ", ".join(cls._formats.keys()) if cls._formats else "none" raise ValueError(f"Unknown quantization format: '{name}'. Available formats: {available}") format_info = cls._formats[name].copy() # Merge override_config into default config config = format_info["config"].copy() config.update(override_config) format_info["config"] = config return format_info @classmethod def list_formats(cls): """List all registered format names""" return list(cls._formats.keys()) @classmethod def is_registered(cls, name): """Check if a format is registered""" return name in cls._formats class FP8ScaledHandler(QuantFormatHandler): """ Handler for FP8 quantization with per-tensor scaling. Supports both weight-only and weight+activation quantization. Compatible with existing fp8_linear implementation. """ def setup_parameters(self): """Initialize scale_weight and optionally scale_input""" device = self.layer.weight.device dtype = torch.float32 # Always have scale_weight for FP8 if not hasattr(self.layer, 'scale_weight') or self.layer.scale_weight is None: self.layer.scale_weight = torch.nn.Parameter( torch.ones((), device=device, dtype=dtype), requires_grad=False ) # scale_input is optional (for activation quantization) if self.config.get("quantize_activation", False): if not hasattr(self.layer, 'scale_input') or self.layer.scale_input is None: self.layer.scale_input = torch.nn.Parameter( torch.ones((), device=device, dtype=dtype), requires_grad=False ) else: self.layer.scale_input = None def forward(self, *args, **kwargs): """ FP8 forward pass with optional activation quantization. Supports Linear layers (Conv2d in future). """ # Detect layer type and dispatch if isinstance(self.layer, torch.nn.Linear): return self._forward_linear(*args, **kwargs) else: raise NotImplementedError( f"FP8ScaledHandler not implemented for {type(self.layer).__name__}" ) def _forward_linear(self, input): """FP8 forward for Linear layers""" # Try fast path with fp8_linear if enabled if self.config.get("use_fp8_matmul", False) and not self.layer.training: try: result = fp8_linear(self.layer, input) if result is not None: return result except Exception as e: logging.debug(f"FP8 matmul failed, falling back to standard path: {e}") # Standard path: dequantize and compute weight, bias = cast_bias_weight(self.layer, input) # Dequantize weight scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) # Apply weight functions (LoRA, etc.) - they see dequantized weights if hasattr(self.layer, 'weight_function') and len(self.layer.weight_function) > 0: weight = weight * scale for f in self.layer.weight_function: weight = f(weight) else: weight = weight * scale if hasattr(self.layer, 'bias_function') and len(self.layer.bias_function) > 0: for f in self.layer.bias_function: bias = f(bias) if bias is not None else None # Execute linear operation # Optimization: multiply by scale on smaller tensor if weight.numel() < input.numel() and len(self.layer.weight_function) == 0: return torch.nn.functional.linear(input, weight, bias) else: return torch.nn.functional.linear(input, weight, bias) def load_state_dict(self, state_dict, prefix): """Load scale parameters from state dict""" scale_weight_key = f"{prefix}scale_weight" if scale_weight_key in state_dict: self.layer.scale_weight.data.copy_(state_dict[scale_weight_key]) scale_input_key = f"{prefix}scale_input" if scale_input_key in state_dict and self.layer.scale_input is not None: self.layer.scale_input.data.copy_(state_dict[scale_input_key]) def state_dict(self, prefix): """Save scale parameters to state dict""" result = {f"{prefix}scale_weight": self.layer.scale_weight} if self.layer.scale_input is not None: result[f"{prefix}scale_input"] = self.layer.scale_input return result def convert_weight(self, weight, inplace=False): """Dequantize: multiply by scale""" scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) if inplace: weight *= scale return weight return weight * scale def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False): """Quantize: divide by scale with stochastic rounding""" scale = self.layer.scale_weight.to(device=weight.device, dtype=weight.dtype) quantized = comfy.float.stochastic_rounding( weight / scale, self.layer.weight.dtype, seed=seed ) if return_weight: return quantized if inplace_update: self.layer.weight.data.copy_(quantized) else: self.layer.weight = torch.nn.Parameter(quantized, requires_grad=False) # ============================================================================== # Mixed Precision Operations # ============================================================================== class MixedPrecisionOps(disable_weight_init): """ Operations class supporting per-layer quantization (mixed precision). This class enables different layers to use different quantization formats within the same model (e.g., some layers FP8, others BF16). Layer-specific quantization is configured via _layer_quant_config class variable, which is set by pick_operations() when a model has mixed precision. """ _layer_quant_config = {} # Class variable set by pick_operations() class Linear(disable_weight_init.Linear): """Linear layer with optional per-layer quantization""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.quant_handler = None self._handler_initialized = False def reset_parameters(self): # Don't allocate weights - return None like disable_weight_init return None def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): """ Called by PyTorch during load_state_dict. This is where we initialize the handler since we now know the layer name. """ if not self._handler_initialized: # Normalize layer name from prefix layer_name = prefix.rstrip('.') # Strip known model prefixes for model_prefix in ["model.diffusion_model.", "model.model.", "net."]: if layer_name.startswith(model_prefix): layer_name = layer_name[len(model_prefix):] break # Check if this layer has quantization config # Access via parent class since _layer_quant_config is a class variable if layer_name in MixedPrecisionOps._layer_quant_config: config = MixedPrecisionOps._layer_quant_config[layer_name] try: format_info = QuantFormatRegistry.get( config["format"], **config.get("params", {}) ) # Initialize handler self.quant_handler = format_info["handler"](self, **format_info["config"]) self.quant_handler.setup_parameters() # Let handler load its parameters (scale_weight, etc.) self.quant_handler.load_state_dict(state_dict, prefix) logging.debug(f"Initialized {config['format']} handler for layer {layer_name}") except ValueError as e: # Format not registered - fall back to standard precision logging.warning( f"Quantization format '{config['format']}' not registered for layer {layer_name}. " f"Falling back to standard precision. Error: {e}" ) self.quant_handler = None except Exception as e: logging.error(f"Failed to initialize quantization handler for {layer_name}: {e}") self.quant_handler = None self._handler_initialized = True # Call parent to load weight and bias super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) def _save_to_state_dict(self, destination, prefix, keep_vars): """Save layer parameters including quantization metadata""" super()._save_to_state_dict(destination, prefix, keep_vars) # Save handler parameters (scale_weight, etc.) if self.quant_handler: handler_dict = self.quant_handler.state_dict(prefix) destination.update(handler_dict) def forward_comfy_cast_weights(self, input): """Forward pass with optional quantization""" if self.quant_handler: # Use handler for quantized forward return self.quant_handler.forward(input) else: # Standard path for non-quantized layers weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) def forward(self, *args, **kwargs): """Main forward pass""" run_every_op() # Same logic as disable_weight_init.Linear if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @classmethod def conv_nd(s, dims, *args, **kwargs): """Create Conv layer (same as disable_weight_init)""" if dims == 2: return s.Conv2d(*args, **kwargs) elif dims == 3: return s.Conv3d(*args, **kwargs) else: raise ValueError(f"unsupported dimensions: {dims}") def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): """ Select appropriate operations class for model. NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3). LEGACY: All other paths unchanged for backward compatibility. Args: weight_dtype: Weight storage dtype compute_dtype: Computation dtype load_device: Device for loading disable_fast_fp8: Disable fast FP8 paths fp8_optimizations: Enable FP8 optimizations scaled_fp8: Legacy FP8 dtype marker model_config: Model config object (optional, for mixed precision support) Returns: Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init) """ # NEW: Check for mixed precision if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps # LEGACY paths (unchanged) fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) if ( fp8_compute and (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and not disable_fast_fp8 ): return fp8_ops if ( PerformanceFeature.CublasOps in args.fast and CUBLAS_IS_AVAILABLE and weight_dtype == torch.float16 and (compute_dtype == torch.float16 or compute_dtype is None) ): logging.info("Using cublas ops") return cublas_ops if compute_dtype is None or weight_dtype == compute_dtype: return disable_weight_init return manual_cast # ============================================================================== # Register built-in quantization formats # ============================================================================== # FP8 E4M3FN weight-only quantization QuantFormatRegistry.register( "fp8_e4m3fn_scaled", handler_class=FP8ScaledHandler, base_dtype=torch.float8_e4m3fn, quantize_activation=False, use_fp8_matmul=True, ) # FP8 E4M3FN weight+activation quantization QuantFormatRegistry.register( "fp8_e4m3fn_scaled_dynamic", handler_class=FP8ScaledHandler, base_dtype=torch.float8_e4m3fn, quantize_activation=True, use_fp8_matmul=True, ) # FP8 E5M2 weight-only quantization QuantFormatRegistry.register( "fp8_e5m2_scaled", handler_class=FP8ScaledHandler, base_dtype=torch.float8_e5m2, quantize_activation=False, use_fp8_matmul=True, )