"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Stability AI

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

import torch
import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib

def run_every_op():
    comfy.model_management.throw_exception_if_processing_interrupted()

def scaled_dot_product_attention(q, k, v, *args, **kwargs):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)


try:
    if torch.cuda.is_available():
        from torch.nn.attention import SDPBackend, sdpa_kernel
        import inspect
        if "set_priority" in inspect.signature(sdpa_kernel).parameters:
            SDPA_BACKEND_PRIORITY = [
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
                SDPBackend.MATH,
            ]

            SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)

            def scaled_dot_product_attention(q, k, v, *args, **kwargs):
                with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
                    return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
        else:
            logging.warning("Torch version too old to set sdpa backend priority.")
except (ModuleNotFoundError, TypeError):
    logging.warning("Could not set sdpa backend priority.")
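
# Illustrative usage sketch (argument names follow torch.nn.functional.scaled_dot_product_attention):
#
#   out = scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
#
# On CUDA with a torch recent enough to accept set_priority, the call runs under the backend
# priority configured above (cuDNN first, then flash / memory-efficient / math kernels);
# otherwise it falls back to the plain wrapper defined earlier in this file.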

cast_to = comfy.model_management.cast_to #TODO: remove once no more references

if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
    torch.backends.cudnn.benchmark = True

def cast_to_input(weight, input, non_blocking=False, copy=True):
    return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if bias_dtype is None:
            bias_dtype = dtype
        if device is None:
            device = input.device

    offload_stream = comfy.model_management.get_offload_stream(device)
    if offload_stream is not None:
        wf_context = offload_stream
    else:
        wf_context = contextlib.nullcontext()

    bias = None
    non_blocking = comfy.model_management.device_supports_non_blocking(device)
    if s.bias is not None:
        has_function = len(s.bias_function) > 0
        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)

        if has_function:
            with wf_context:
                for f in s.bias_function:
                    bias = f(bias)

    has_function = len(s.weight_function) > 0
    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
    if has_function:
        with wf_context:
            for f in s.weight_function:
                weight = f(weight)

    comfy.model_management.sync_stream(device, offload_stream)
    return weight, bias
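
# Illustrative sketch (hypothetical names): entries in a layer's `weight_function` /
# `bias_function` lists are callables that cast_bias_weight() applies to the freshly
# cast copy of the weight or bias. A LoRA-style patch could, for example, be attached as:
#
#   def apply_patch(w):
#       return w + alpha * torch.mm(lora_up, lora_down).to(w.dtype)
#   layer.weight_function.append(apply_patch)
#
# `alpha`, `lora_up` and `lora_down` above are placeholders, not names defined in this
# module; the only contract is "takes a tensor, returns a tensor usable as the weight".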

class CastWeightBiasOp:
    comfy_cast_weights = False
    weight_function = []
    bias_function = []

class disable_weight_init:
    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            run_every_op()
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")
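
# Illustrative sketch of how these op containers are meant to be used (the `operations`
# keyword and the attribute names below follow the pattern ComfyUI model code generally
# uses, but nothing in this sketch is defined in this file):
#
#   def __init__(self, in_dim, out_dim, dtype=None, device=None, operations=disable_weight_init):
#       self.proj = operations.Linear(in_dim, out_dim, bias=True, dtype=dtype, device=device)
#
# Because reset_parameters() returns None, constructing layers this way skips the usual
# random init; the parameters are expected to be overwritten by a checkpoint load.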


class manual_cast(disable_weight_init):
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True
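
# manual_cast reuses every layer from disable_weight_init but sets comfy_cast_weights,
# so weights and biases are cast to the input's dtype and device on every forward pass.
# This is the fallback when the stored weight dtype differs from the compute dtype
# (see pick_operations() at the bottom of this file).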


def fp8_linear(self, input):
    """
    Legacy FP8 linear function - now uses tensor subclass infrastructure.

    This function maintains backward compatibility with existing code while
    routing all FP8 computation through the unified tensor subclass system.
    All actual FP8 matmul logic is handled by the registered operation handlers
    in quant_ops.py via __torch_dispatch__.

    Args:
        self: Linear layer with FP8 weight and scale parameters
        input: Input tensor (any dtype)

    Returns:
        Output tensor or None if weight is not FP8
    """
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    tensor_2d = False
    if len(input.shape) == 2:
        tensor_2d = True
        input = input.unsqueeze(1)

    input_shape = input.shape
    input_dtype = input.dtype

    if len(input.shape) == 3:
        # Get weight and bias using standard casting
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)

        # Get scales (same as before)
        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_weight = scale_weight.to(input.device)

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_input = scale_input.to(input.device)

        # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
        quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
        quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
        # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
        # This is the key unification: all FP8 computation goes through one path
        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

        # Reshape output
        if tensor_2d:
            return o.reshape(input_shape[0], -1)
        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

    return None

class fp8_ops(manual_cast):
    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            if not self.training:
                try:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out
                except Exception as e:
                    logging.info("Exception during fp8 op: {}".format(e))

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
                if override_dtype is not None:
                    kwargs['dtype'] = override_dtype
                super().__init__(*args, **kwargs)

            def reset_parameters(self):
                if not hasattr(self, 'scale_weight'):
                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)

                if not scale_input:
                    self.scale_input = None

                if not hasattr(self, 'scale_input'):
                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                return None

            def forward_comfy_cast_weights(self, input):
                if fp8_matrix_mult:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out

                weight, bias = cast_bias_weight(self, input)

                if weight.numel() < input.numel(): #TODO: optimize
                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
                else:
                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                if inplace:
                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                    return weight
                else:
                    return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
                if return_weight:
                    return weight
                if inplace_update:
                    self.weight.data.copy_(weight)
                else:
                    self.weight = torch.nn.Parameter(weight, requires_grad=False)

    return scaled_fp8_op
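
# Illustrative sketch: pick_operations() below builds one of these classes per model,
# roughly as
#
#   ops = scaled_fp8_ops(fp8_matrix_mult=True, scale_input=True, override_dtype=torch.float8_e4m3fn)
#   layer = ops.Linear(in_features, out_features, device=load_device)
#
# where `in_features`, `out_features` and `load_device` are placeholders for whatever
# the model definition and loader supply.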

CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    pass

if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        class Linear(CublasLinear, disable_weight_init.Linear):
            def reset_parameters(self):
                return None

            def forward_comfy_cast_weights(self, input):
                return super().forward(input)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)


# Import quantization operations from separate module
from .quant_ops import QuantizedTensorFP8


# ==============================================================================
# Mixed Precision Operations
# ==============================================================================

class MixedPrecisionOps(disable_weight_init):
    """
    Operations class supporting per-layer quantization (mixed precision).

    This class enables different layers to use different quantization formats
    within the same model (e.g., some layers FP8, others BF16).

    Layer-specific quantization is configured via the _layer_quant_config class
    variable, which is set by pick_operations() when a model has mixed precision.
    """

    _layer_quant_config = {}  # Class variable set by pick_operations()

    class Linear(disable_weight_init.Linear):
        """Linear layer with optional per-layer quantization using tensor subclasses"""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.quant_format = None
            self.quant_scale = None
            self._quantization_initialized = False

        def reset_parameters(self):
            # Don't allocate weights - return None like disable_weight_init
            return None

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys, error_msgs):
            """
            Called by PyTorch during load_state_dict.
            Load weight and wrap in QuantizedTensorFP8 if this layer is quantized.
            """
            # Call parent to load weight and bias first
            super()._load_from_state_dict(
                state_dict, prefix, local_metadata,
                strict, missing_keys, unexpected_keys, error_msgs
            )

            # After weight is loaded, wrap it if this layer is quantized
            if not self._quantization_initialized:
                # Normalize layer name from prefix
                layer_name = prefix.rstrip('.')

                # Strip known model prefixes
                for model_prefix in ["model.diffusion_model.", "model.model.", "net."]:
                    if layer_name.startswith(model_prefix):
                        layer_name = layer_name[len(model_prefix):]
                        break

                # Check if this layer has quantization config
                if layer_name in MixedPrecisionOps._layer_quant_config:
                    config = MixedPrecisionOps._layer_quant_config[layer_name]
                    self.quant_format = config.get("format", "fp8_e4m3fn")

                    # Load scale parameter
                    scale_key = f"{prefix}scale_weight"
                    if scale_key in state_dict:
                        self.quant_scale = state_dict[scale_key]

                        # Wrap weight in QuantizedTensorFP8
                        if self.weight is not None and self.weight.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                            try:
                                # Determine original dtype (default to bfloat16)
                                orig_dtype = torch.bfloat16

                                # Wrap weight in quantized tensor subclass
                                quantized_weight = QuantizedTensorFP8(
                                    self.weight.data,
                                    self.quant_scale,
                                    orig_dtype=orig_dtype
                                )

                                # Replace weight parameter with wrapped version
                                self.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)

                                logging.debug(f"Wrapped layer {layer_name} weight in QuantizedTensorFP8 (format: {self.quant_format})")
                            except Exception as e:
                                logging.warning(f"Failed to wrap layer {layer_name} in QuantizedTensorFP8: {e}")
                                self.quant_format = None
                                self.quant_scale = None
                        else:
                            logging.debug(f"Layer {layer_name} has scale but weight dtype is not FP8, skipping quantization")
                            self.quant_format = None
                            self.quant_scale = None
                    else:
                        logging.debug(f"Layer {layer_name} has quant config but no scale_weight in state_dict")
                        self.quant_format = None

                self._quantization_initialized = True

        def _save_to_state_dict(self, destination, prefix, keep_vars):
            """Save layer parameters including quantization scale"""
            # First unwrap the weight if it's quantized
            if isinstance(self.weight, torch.nn.Parameter) and isinstance(self.weight.data, QuantizedTensorFP8):
                # Temporarily unwrap to save the raw FP8 data
                quantized_tensor = self.weight.data
                raw_fp8_data = quantized_tensor._raw_data
                original_weight = self.weight
                self.weight = torch.nn.Parameter(raw_fp8_data, requires_grad=False)

                # Call parent to save unwrapped weight
                super()._save_to_state_dict(destination, prefix, keep_vars)

                # Restore the wrapped weight
                self.weight = original_weight

                # Save the scale parameter
                if self.quant_scale is not None:
                    destination[f"{prefix}scale_weight"] = self.quant_scale if keep_vars else self.quant_scale.detach()
            else:
                # Standard path for non-quantized weights
                super()._save_to_state_dict(destination, prefix, keep_vars)

        def forward_comfy_cast_weights(self, input):
            """
            Forward pass - tensor subclass handles dispatch automatically!
            __torch_dispatch__ will route to registered handlers based on tensor types.
            """
            weight, bias = cast_bias_weight(self, input)

            # Call F.linear - if weight is QuantizedTensorFP8, __torch_dispatch__ handles it!
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            """Main forward pass"""
            run_every_op()
            # Same logic as disable_weight_init.Linear
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        """Create Conv layer (same as disable_weight_init)"""
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")
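
# Illustrative sketch of the per-layer config consumed above (the layer name is a
# placeholder, not taken from a real model):
#
#   MixedPrecisionOps._layer_quant_config = {
#       "input_blocks.0.0.some_linear": {"format": "fp8_e4m3fn"},
#   }
#
# pick_operations() installs model_config.layer_quant_config here, and each quantized
# layer is expected to ship a matching "<layer name>.scale_weight" entry in the
# checkpoint's state dict.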


def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
    """
    Select the appropriate operations class for a model.

    NEW: If model_config.layer_quant_config exists, returns MixedPrecisionOps (Phase 3).
    LEGACY: All other paths unchanged for backward compatibility.

    Args:
        weight_dtype: Weight storage dtype
        compute_dtype: Computation dtype
        load_device: Device for loading
        disable_fast_fp8: Disable fast FP8 paths
        fp8_optimizations: Enable FP8 optimizations
        scaled_fp8: Legacy FP8 dtype marker
        model_config: Model config object (optional, for mixed precision support)

    Returns:
        Operations class (e.g., MixedPrecisionOps, fp8_ops, disable_weight_init)
    """
    # NEW: Check for mixed precision
    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
        MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
        return MixedPrecisionOps

    # LEGACY paths (unchanged)
    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

    if (
        fp8_compute and
        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
        not disable_fast_fp8
    ):
        return fp8_ops

    if (
        PerformanceFeature.CublasOps in args.fast and
        CUBLAS_IS_AVAILABLE and
        weight_dtype == torch.float16 and
        (compute_dtype == torch.float16 or compute_dtype is None)
    ):
        logging.info("Using cublas ops")
        return cublas_ops

    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast
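
# Illustrative sketch of a typical call (dtypes and device are placeholders chosen by the
# model loader, not values fixed by this module):
#
#   operations = pick_operations(torch.float8_e4m3fn, torch.bfloat16, load_device=load_device)
#   layer = operations.Linear(in_features, out_features, bias=True)
#
# With default flags, no layer_quant_config and a weight dtype that differs from the
# compute dtype, this returns manual_cast, so layers cast their weights to the
# activation's dtype and device on every forward pass.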