mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Refactor mixed precision ops to share more code
This commit is contained in:
parent
0b0f1b1cf6
commit
97c8bdb781
599
comfy/ops.py
599
comfy/ops.py
@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
import contextlib
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy.float
|
import comfy.float
|
||||||
@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
|
|||||||
|
|
||||||
return grad_input, grad_weight, grad_bias, None, None, None
|
return grad_input, grad_weight, grad_bias, None, None, None
|
||||||
|
|
||||||
|
# Quantized-weight module helpers
|
||||||
|
|
||||||
|
def _quantized_apply(module, fn, recurse=True):
|
||||||
|
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
|
||||||
|
if recurse:
|
||||||
|
for child in module.children():
|
||||||
|
child._apply(fn)
|
||||||
|
for key, param in module._parameters.items():
|
||||||
|
if param is None:
|
||||||
|
continue
|
||||||
|
p = fn(param)
|
||||||
|
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||||
|
p = p.clone()
|
||||||
|
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||||
|
for key, buf in module._buffers.items():
|
||||||
|
if buf is not None:
|
||||||
|
module._buffers[key] = fn(buf)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
|
||||||
|
missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
|
||||||
|
"""Shared _load_from_state_dict body for quantized-weight modules.
|
||||||
|
|
||||||
|
Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
|
||||||
|
or Parameter-wrapped QuantizedTensor, then calls super_load and strips
|
||||||
|
consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
|
||||||
|
and disabled formats from module._disabled_formats.
|
||||||
|
"""
|
||||||
|
device = module.factory_kwargs["device"]
|
||||||
|
compute_dtype = module.factory_kwargs["dtype"]
|
||||||
|
disabled_formats = module._disabled_formats
|
||||||
|
layer_name = prefix.rstrip('.')
|
||||||
|
|
||||||
|
weight = state_dict.pop(f"{prefix}weight", None)
|
||||||
|
if weight is None:
|
||||||
|
logging.warning(f"Missing weight for layer {layer_name}")
|
||||||
|
module.weight = None
|
||||||
|
return
|
||||||
|
manually_loaded_keys = [f"{prefix}weight"]
|
||||||
|
|
||||||
|
def pop_scale(name, dtype=None):
|
||||||
|
key = f"{prefix}{name}"
|
||||||
|
v = state_dict.pop(key, None)
|
||||||
|
if v is not None:
|
||||||
|
v = v.to(device=device)
|
||||||
|
if dtype is not None:
|
||||||
|
v = v.view(dtype=dtype)
|
||||||
|
manually_loaded_keys.append(key)
|
||||||
|
return v
|
||||||
|
|
||||||
|
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||||
|
if layer_conf is not None:
|
||||||
|
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||||
|
|
||||||
|
if layer_conf is None:
|
||||||
|
module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
|
||||||
|
else:
|
||||||
|
module.quant_format = layer_conf.get("format", None)
|
||||||
|
module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||||
|
if not module._full_precision_mm:
|
||||||
|
module._full_precision_mm = module._full_precision_mm_config
|
||||||
|
if module.quant_format in disabled_formats:
|
||||||
|
module._full_precision_mm = True
|
||||||
|
if module.quant_format is None:
|
||||||
|
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||||
|
|
||||||
|
qconfig = QUANT_ALGOS[module.quant_format]
|
||||||
|
module.layout_type = qconfig["comfy_tensor_layout"]
|
||||||
|
layout_cls = get_layout_class(module.layout_type)
|
||||||
|
|
||||||
|
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
|
||||||
|
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||||
|
scales = {"scale": pop_scale("weight_scale")}
|
||||||
|
elif module.quant_format == "mxfp8":
|
||||||
|
bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
|
||||||
|
if bs is None:
|
||||||
|
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||||
|
scales = {"scale": bs}
|
||||||
|
elif module.quant_format == "nvfp4":
|
||||||
|
ts = pop_scale("weight_scale_2")
|
||||||
|
bs = pop_scale("weight_scale", torch.float8_e4m3fn)
|
||||||
|
if ts is None or bs is None:
|
||||||
|
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||||
|
scales = {"scale": ts, "block_scale": bs}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||||
|
|
||||||
|
params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
|
||||||
|
module.weight = torch.nn.Parameter(
|
||||||
|
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_extra_params:
|
||||||
|
for param_name in qconfig["parameters"]:
|
||||||
|
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||||
|
continue
|
||||||
|
param_key = f"{prefix}{param_name}"
|
||||||
|
_v = state_dict.pop(param_key, None)
|
||||||
|
if _v is None:
|
||||||
|
continue
|
||||||
|
module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||||
|
manually_loaded_keys.append(param_key)
|
||||||
|
|
||||||
|
super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
|
for key in manually_loaded_keys:
|
||||||
|
if key in missing_keys:
|
||||||
|
missing_keys.remove(key)
|
||||||
|
|
||||||
|
|
||||||
|
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
|
||||||
|
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
|
||||||
|
extra_quant_params names attributes written as additional top-level keys."""
|
||||||
|
if not hasattr(module, 'weight'):
|
||||||
|
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
|
||||||
|
return sd
|
||||||
|
bias = getattr(module, 'bias', None)
|
||||||
|
if bias is not None:
|
||||||
|
sd[f"{prefix}bias"] = bias
|
||||||
|
if module.weight is None:
|
||||||
|
return sd
|
||||||
|
if isinstance(module.weight, QuantizedTensor):
|
||||||
|
sd.update(module.weight.state_dict(f"{prefix}weight"))
|
||||||
|
quant_conf = {"format": module.quant_format}
|
||||||
|
if getattr(module, '_full_precision_mm_config', False):
|
||||||
|
quant_conf["full_precision_matrix_mult"] = True
|
||||||
|
if extra_quant_conf:
|
||||||
|
quant_conf.update(extra_quant_conf)
|
||||||
|
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||||
|
for name in extra_quant_params:
|
||||||
|
value = getattr(module, name, None)
|
||||||
|
if value is not None:
|
||||||
|
sd[f"{prefix}{name}"] = value
|
||||||
|
else:
|
||||||
|
sd[f"{prefix}weight"] = module.weight
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||||
class MixedPrecisionOps(manual_cast):
|
class MixedPrecisionOps(manual_cast):
|
||||||
@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
_disabled = disabled
|
_disabled = disabled
|
||||||
|
|
||||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||||
def __init__(
|
_disabled_formats = disabled
|
||||||
self,
|
|
||||||
in_features: int,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||||
out_features: int,
|
|
||||||
bias: bool = True,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
|
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
|
self._orig_shape = (out_features, in_features)
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||||
else:
|
else:
|
||||||
@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
def _load_from_state_dict(self, *args):
|
||||||
key = f"{prefix}{param_name}"
|
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
|
||||||
value = state_dict.pop(key, None)
|
|
||||||
if value is not None:
|
|
||||||
value = value.to(device=device)
|
|
||||||
if dtype is not None:
|
|
||||||
value = value.view(dtype=dtype)
|
|
||||||
manually_loaded_keys.append(key)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
|
||||||
|
|
||||||
device = self.factory_kwargs["device"]
|
|
||||||
layer_name = prefix.rstrip('.')
|
|
||||||
weight_key = f"{prefix}weight"
|
|
||||||
weight = state_dict.pop(weight_key, None)
|
|
||||||
if weight is None:
|
|
||||||
logging.warning(f"Missing weight for layer {layer_name}")
|
|
||||||
self.weight = None
|
|
||||||
return
|
|
||||||
|
|
||||||
manually_loaded_keys = [weight_key]
|
|
||||||
|
|
||||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
|
||||||
if layer_conf is not None:
|
|
||||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
|
||||||
|
|
||||||
if layer_conf is None:
|
|
||||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
|
||||||
else:
|
|
||||||
self.quant_format = layer_conf.get("format", None)
|
|
||||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
|
||||||
if not self._full_precision_mm:
|
|
||||||
self._full_precision_mm = self._full_precision_mm_config
|
|
||||||
|
|
||||||
if self.quant_format in MixedPrecisionOps._disabled:
|
|
||||||
self._full_precision_mm = True
|
|
||||||
|
|
||||||
if self.quant_format is None:
|
|
||||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
|
||||||
|
|
||||||
qconfig = QUANT_ALGOS[self.quant_format]
|
|
||||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
|
||||||
layout_cls = get_layout_class(self.layout_type)
|
|
||||||
|
|
||||||
# Load format-specific parameters
|
|
||||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
|
||||||
# FP8: single tensor scale
|
|
||||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
|
||||||
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=scale,
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=(self.out_features, self.in_features),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif self.quant_format == "mxfp8":
|
|
||||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
|
||||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
|
||||||
dtype=torch.uint8)
|
|
||||||
|
|
||||||
if block_scale is None:
|
|
||||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
|
||||||
|
|
||||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
|
||||||
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=block_scale,
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=(self.out_features, self.in_features),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif self.quant_format == "nvfp4":
|
|
||||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
|
||||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
|
||||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
|
||||||
dtype=torch.float8_e4m3fn)
|
|
||||||
|
|
||||||
if tensor_scale is None or block_scale is None:
|
|
||||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
|
||||||
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=tensor_scale,
|
|
||||||
block_scale=block_scale,
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=(self.out_features, self.in_features),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
|
||||||
|
|
||||||
self.weight = torch.nn.Parameter(
|
|
||||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
|
||||||
requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
for param_name in qconfig["parameters"]:
|
|
||||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
|
||||||
continue # Already handled above
|
|
||||||
|
|
||||||
param_key = f"{prefix}{param_name}"
|
|
||||||
_v = state_dict.pop(param_key, None)
|
|
||||||
if _v is None:
|
|
||||||
continue
|
|
||||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
|
||||||
manually_loaded_keys.append(param_key)
|
|
||||||
|
|
||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
|
||||||
|
|
||||||
for key in manually_loaded_keys:
|
|
||||||
if key in missing_keys:
|
|
||||||
missing_keys.remove(key)
|
|
||||||
|
|
||||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||||
if destination is not None:
|
sd = destination if destination is not None else {}
|
||||||
sd = destination
|
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
|
||||||
else:
|
|
||||||
sd = {}
|
|
||||||
|
|
||||||
if not hasattr(self, 'weight'):
|
|
||||||
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
|
||||||
return sd
|
|
||||||
|
|
||||||
if self.bias is not None:
|
|
||||||
sd["{}bias".format(prefix)] = self.bias
|
|
||||||
|
|
||||||
if self.weight is None:
|
|
||||||
return sd
|
|
||||||
|
|
||||||
if isinstance(self.weight, QuantizedTensor):
|
|
||||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
|
||||||
for k in sd_out:
|
|
||||||
sd[k] = sd_out[k]
|
|
||||||
|
|
||||||
quant_conf = {"format": self.quant_format}
|
|
||||||
if self._full_precision_mm_config:
|
|
||||||
quant_conf["full_precision_matrix_mult"] = True
|
|
||||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
|
||||||
|
|
||||||
input_scale = getattr(self, 'input_scale', None)
|
|
||||||
if input_scale is not None:
|
|
||||||
sd["{}input_scale".format(prefix)] = input_scale
|
|
||||||
else:
|
|
||||||
sd["{}weight".format(prefix)] = self.weight
|
|
||||||
return sd
|
|
||||||
|
|
||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
@ -1317,46 +1312,34 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||||
|
|
||||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||||
if recurse:
|
return _quantized_apply(self, fn, recurse)
|
||||||
for module in self.children():
|
|
||||||
module._apply(fn)
|
|
||||||
|
|
||||||
for key, param in self._parameters.items():
|
class MoEExperts(torch.nn.Module, CastWeightBiasOp):
|
||||||
if param is None:
|
"""Container for E quantized expert weights, indexed via expert_weight(i).
|
||||||
continue
|
|
||||||
p = fn(param)
|
|
||||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
|
||||||
p = p.clone()
|
|
||||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
|
||||||
for key, buf in self._buffers.items():
|
|
||||||
if buf is not None:
|
|
||||||
self._buffers[key] = fn(buf)
|
|
||||||
return self
|
|
||||||
|
|
||||||
class MoEExperts(CastWeightBiasOp, torch.nn.Module):
|
The bank lives on self.weight as a single 3D tensor — either a
|
||||||
"""Container for E quantized expert weights, indexed via ``expert_weight(i)``.
|
compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
|
||||||
|
|
||||||
The full bank lives on ``self.weight`` as a single (3D) tensor — either
|
|
||||||
a bf16 ``Parameter`` or a ``Parameter`` wrapping a ``QuantizedTensor``
|
|
||||||
with leading expert dim.
|
with leading expert dim.
|
||||||
|
|
||||||
State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
|
State-dict layout matches mixed_precision_ops.Linear with a leading
|
||||||
leading expert dim — exact storage shape is layout-specific)::
|
expert dim:
|
||||||
|
|
||||||
{prefix}.weight quant data (storage_t), leading dim = E
|
{prefix}.weight quant data (storage_t), leading dim = E
|
||||||
{prefix}.weight_scale block / per-tensor scale
|
{prefix}.weight_scale block / per-tensor scale
|
||||||
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
|
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
|
||||||
{prefix}.bias [E, out_features] optional, bf16
|
{prefix}.bias [E, out_features] optional, compute_dtype
|
||||||
{prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
|
{prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
|
||||||
|
|
||||||
Without ``comfy_quant`` the weight loads as a plain bf16 3D Parameter ``[E, out, in]``.
|
Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_disabled_formats = disabled
|
||||||
|
|
||||||
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
|
self._orig_shape = (num_experts, out_features, in_features)
|
||||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
|
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
|
||||||
@ -1369,119 +1352,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
self.layout_type = None
|
self.layout_type = None
|
||||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||||
self._full_precision_mm_config = False
|
self._full_precision_mm_config = False
|
||||||
|
self._resident_bank = None
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
# No-op so module init doesn't clobber the loaded quant weights.
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _apply(self, fn, recurse=True):
|
def _apply(self, fn, recurse=True):
|
||||||
# Mirror Linear._apply: re-wrap each Parameter so .to()/.cuda()
|
return _quantized_apply(self, fn, recurse)
|
||||||
# propagate through the QuantizedTensor wrapped inside self.weight.
|
|
||||||
if recurse:
|
|
||||||
for module in self.children():
|
|
||||||
module._apply(fn)
|
|
||||||
for key, param in self._parameters.items():
|
|
||||||
if param is None:
|
|
||||||
continue
|
|
||||||
p = fn(param)
|
|
||||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
|
||||||
p = p.clone()
|
|
||||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
|
||||||
for key, buf in self._buffers.items():
|
|
||||||
if buf is not None:
|
|
||||||
self._buffers[key] = fn(buf)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def _load_scale_param(self, state_dict, prefix, param_name, device,
|
def _load_from_state_dict(self, *args):
|
||||||
manually_loaded_keys, dtype=None):
|
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)
|
||||||
key = f"{prefix}{param_name}"
|
|
||||||
value = state_dict.pop(key, None)
|
|
||||||
if value is not None:
|
|
||||||
value = value.to(device=device)
|
|
||||||
if dtype is not None:
|
|
||||||
value = value.view(dtype=dtype)
|
|
||||||
manually_loaded_keys.append(key)
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
|
||||||
device = self.factory_kwargs["device"]
|
|
||||||
layer_name = prefix.rstrip(".")
|
|
||||||
weight_key = f"{prefix}weight"
|
|
||||||
weight = state_dict.pop(weight_key, None)
|
|
||||||
if weight is None:
|
|
||||||
logging.warning(f"Missing weight for MoEExperts layer {layer_name}")
|
|
||||||
return
|
|
||||||
manually_loaded_keys = [weight_key]
|
|
||||||
|
|
||||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
|
||||||
if layer_conf is not None:
|
|
||||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
|
||||||
manually_loaded_keys.append(f"{prefix}comfy_quant")
|
|
||||||
|
|
||||||
if layer_conf is None:
|
|
||||||
self.weight = torch.nn.Parameter(
|
|
||||||
weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.quant_format = layer_conf.get("format")
|
|
||||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
|
||||||
if not self._full_precision_mm:
|
|
||||||
self._full_precision_mm = self._full_precision_mm_config
|
|
||||||
|
|
||||||
if self.quant_format in MixedPrecisionOps._disabled:
|
|
||||||
self._full_precision_mm = True
|
|
||||||
|
|
||||||
if self.quant_format is None:
|
|
||||||
raise ValueError(f"Unknown quant format for MoEExperts layer {layer_name}")
|
|
||||||
|
|
||||||
qconfig = QUANT_ALGOS[self.quant_format]
|
|
||||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
|
||||||
layout_cls = get_layout_class(self.layout_type)
|
|
||||||
orig_shape = (self.num_experts, self.out_features, self.in_features)
|
|
||||||
|
|
||||||
# Scales keep their leading expert dim; per-expert slicing happens at access.
|
|
||||||
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
|
||||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=scale,
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
elif self.quant_format == "mxfp8":
|
|
||||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
|
||||||
if block_scale is None:
|
|
||||||
raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=block_scale,
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
elif self.quant_format == "nvfp4":
|
|
||||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
|
||||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
|
||||||
if tensor_scale is None or block_scale is None:
|
|
||||||
raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=tensor_scale,
|
|
||||||
block_scale=block_scale,
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")
|
|
||||||
|
|
||||||
qdata = weight.to(device=device, dtype=qconfig["storage_t"])
|
|
||||||
self.weight = torch.nn.Parameter(
|
|
||||||
QuantizedTensor(qdata, self.layout_type, params),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
|
||||||
for k in manually_loaded_keys:
|
|
||||||
if k in missing_keys:
|
|
||||||
missing_keys.remove(k)
|
|
||||||
|
|
||||||
def expert_weight(self, i: int):
|
def expert_weight(self, i: int):
|
||||||
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
|
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
|
||||||
@ -1489,76 +1369,69 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
return self._expert_qt_from(self.weight, i)
|
return self._expert_qt_from(self.weight, i)
|
||||||
return self.weight[i]
|
return self.weight[i]
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def bank_resident(self, input):
|
||||||
|
"""Cast the whole bank once; expert_linear inside reuses the cast.
|
||||||
|
Not re-entrant — do not nest calls on the same instance.
|
||||||
|
"""
|
||||||
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
|
self._resident_bank = (weight, bias)
|
||||||
|
try:
|
||||||
|
yield self
|
||||||
|
finally:
|
||||||
|
self._resident_bank = None
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
|
||||||
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
||||||
"""Linear against expert ``i``'s weight (with optional bias)."""
|
"""Linear against expert i's weight (with optional bias)."""
|
||||||
|
resident = getattr(self, "_resident_bank", None)
|
||||||
|
if resident is not None:
|
||||||
|
weight, bias = resident
|
||||||
|
return self._expert_linear_impl(input, weight, bias, i)
|
||||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
try:
|
try:
|
||||||
if isinstance(weight, QuantizedTensor):
|
return self._expert_linear_impl(input, weight, bias, i)
|
||||||
qw = self._expert_qt_from(weight, i)
|
|
||||||
else:
|
|
||||||
qw = weight[i]
|
|
||||||
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
|
|
||||||
|
|
||||||
if isinstance(qw, QuantizedTensor):
|
|
||||||
use_fast = (
|
|
||||||
not self._full_precision_mm
|
|
||||||
and qw.layout_cls.supports_fast_matmul()
|
|
||||||
and input.dim() == 2
|
|
||||||
)
|
|
||||||
if use_fast:
|
|
||||||
qin = QuantizedTensor.from_float(input, self.layout_type)
|
|
||||||
return torch.nn.functional.linear(qin, qw, b)
|
|
||||||
out = input @ qw.dequantize().t()
|
|
||||||
return out + b if b is not None else out
|
|
||||||
return torch.nn.functional.linear(input, qw, b)
|
|
||||||
finally:
|
finally:
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
|
||||||
def _expert_qt_from(self, weight: "QuantizedTensor", i: int) -> "QuantizedTensor":
|
def _expert_linear_impl(self, input, weight, bias, i):
|
||||||
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
|
if isinstance(weight, QuantizedTensor):
|
||||||
qdata = weight._qdata[i]
|
qw = self._expert_qt_from(weight, i)
|
||||||
params = weight._params
|
|
||||||
orig_shape = (self.out_features, self.in_features)
|
|
||||||
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
|
||||||
scale = params.scale[i] if params.scale.dim() else params.scale
|
|
||||||
per_expert_params = type(params)(
|
|
||||||
scale=scale, orig_dtype=params.orig_dtype, orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
elif self.quant_format == "mxfp8":
|
|
||||||
per_expert_params = type(params)(
|
|
||||||
scale=params.scale[i], orig_dtype=params.orig_dtype, orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
elif self.quant_format == "nvfp4":
|
|
||||||
scale = params.scale[i] if params.scale.dim() else params.scale
|
|
||||||
per_expert_params = type(params)(
|
|
||||||
scale=scale, block_scale=params.block_scale[i],
|
|
||||||
orig_dtype=params.orig_dtype, orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported quant format: {self.quant_format}")
|
qw = weight[i]
|
||||||
return QuantizedTensor(qdata, weight._layout_cls, per_expert_params)
|
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
|
||||||
|
|
||||||
|
if isinstance(qw, QuantizedTensor):
|
||||||
|
use_fast = (
|
||||||
|
not self._full_precision_mm
|
||||||
|
and qw.layout_cls.supports_fast_matmul()
|
||||||
|
and input.dim() == 2
|
||||||
|
)
|
||||||
|
if use_fast:
|
||||||
|
qin = QuantizedTensor.from_float(input, self.layout_type)
|
||||||
|
return torch.nn.functional.linear(qin, qw, b)
|
||||||
|
out = input @ qw.dequantize().t()
|
||||||
|
return out + b if b is not None else out
|
||||||
|
return torch.nn.functional.linear(input, qw, b)
|
||||||
|
|
||||||
|
def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
|
||||||
|
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
|
||||||
|
params = weight._params
|
||||||
|
kwargs = {
|
||||||
|
"scale": params.scale[i] if params.scale.dim() else params.scale,
|
||||||
|
"orig_dtype": params.orig_dtype,
|
||||||
|
"orig_shape": (self.out_features, self.in_features),
|
||||||
|
}
|
||||||
|
if hasattr(params, "block_scale"): # NVFP4
|
||||||
|
kwargs["block_scale"] = params.block_scale[i]
|
||||||
|
return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))
|
||||||
|
|
||||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||||
sd = destination if destination is not None else {}
|
sd = destination if destination is not None else {}
|
||||||
if self.bias is not None:
|
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
|
||||||
sd[f"{prefix}bias"] = self.bias
|
|
||||||
if self.weight is None:
|
|
||||||
return sd
|
|
||||||
if isinstance(self.weight, QuantizedTensor):
|
|
||||||
sd.update(self.weight.state_dict(f"{prefix}weight"))
|
|
||||||
quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
|
|
||||||
if self._full_precision_mm_config:
|
|
||||||
quant_conf["full_precision_matrix_mult"] = True
|
|
||||||
sd[f"{prefix}comfy_quant"] = torch.tensor(
|
|
||||||
list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sd[f"{prefix}weight"] = self.weight
|
|
||||||
return sd
|
|
||||||
|
|
||||||
class Embedding(manual_cast.Embedding):
|
class Embedding(manual_cast.Embedding):
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
|
||||||
weight_key = f"{prefix}weight"
|
weight_key = f"{prefix}weight"
|
||||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||||
if layer_conf is not None:
|
if layer_conf is not None:
|
||||||
@ -1566,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
quant_format = layer_conf.get("format") if layer_conf is not None else None
|
||||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
manually_loaded_keys = []
|
||||||
|
|
||||||
|
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
|
||||||
self.quant_format = quant_format
|
self.quant_format = quant_format
|
||||||
qconfig = QUANT_ALGOS[quant_format]
|
qconfig = QUANT_ALGOS[quant_format]
|
||||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||||
layout_cls = get_layout_class(self.layout_type)
|
layout_cls = get_layout_class(self.layout_type)
|
||||||
weight = state_dict.pop(weight_key)
|
weight = state_dict.pop(weight_key)
|
||||||
manually_loaded_keys = [weight_key]
|
manually_loaded_keys.append(weight_key)
|
||||||
|
|
||||||
scale_key = f"{prefix}weight_scale"
|
scale_key = f"{prefix}weight_scale"
|
||||||
scale = state_dict.pop(scale_key, None)
|
scale = state_dict.pop(scale_key, None)
|
||||||
@ -1589,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
self.weight = torch.nn.Parameter(
|
self.weight = torch.nn.Parameter(
|
||||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
elif layer_conf is not None:
|
||||||
|
# Unsupported format — restore the marker so it round-trips; fall through to default load.
|
||||||
|
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
|
||||||
|
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||||
|
|
||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
for k in manually_loaded_keys:
|
for k in manually_loaded_keys:
|
||||||
if k in missing_keys:
|
if k in missing_keys:
|
||||||
missing_keys.remove(k)
|
missing_keys.remove(k)
|
||||||
else:
|
|
||||||
if layer_conf is not None:
|
|
||||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
|
||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
|
||||||
|
|
||||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||||
if destination is not None:
|
sd = destination if destination is not None else {}
|
||||||
sd = destination
|
return _quantized_weight_state_dict(self, sd, prefix)
|
||||||
else:
|
|
||||||
sd = {}
|
|
||||||
|
|
||||||
if not hasattr(self, 'weight') or self.weight is None:
|
|
||||||
return sd
|
|
||||||
|
|
||||||
if isinstance(self.weight, QuantizedTensor):
|
|
||||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
|
||||||
for k in sd_out:
|
|
||||||
sd[k] = sd_out[k]
|
|
||||||
|
|
||||||
quant_conf = {"format": self.quant_format}
|
|
||||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
|
||||||
else:
|
|
||||||
sd["{}weight".format(prefix)] = self.weight
|
|
||||||
return sd
|
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||||
weight = self.weight
|
weight = self.weight
|
||||||
|
|||||||
@ -218,19 +218,21 @@ class GptOssExperts(nn.Module):
|
|||||||
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
||||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||||
|
|
||||||
for ei in expert_hit:
|
with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
|
||||||
expert_idx = int(ei.item())
|
self.down_proj.bank_resident(hidden_states) as down_bank:
|
||||||
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
for ei in expert_hit:
|
||||||
current = hidden_states[token_idx]
|
expert_idx = int(ei.item())
|
||||||
|
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
||||||
|
current = hidden_states[token_idx]
|
||||||
|
|
||||||
gate_up = self.gate_up_proj.expert_linear(current, expert_idx)
|
gate_up = gate_up_bank.expert_linear(current, expert_idx)
|
||||||
gated = self._apply_gate(gate_up)
|
gated = self._apply_gate(gate_up)
|
||||||
expert_out = self.down_proj.expert_linear(gated, expert_idx)
|
expert_out = down_bank.expert_linear(gated, expert_idx)
|
||||||
|
|
||||||
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
||||||
|
|
||||||
flat_idx = token_idx * top_k + top_k_pos
|
flat_idx = token_idx * top_k + top_k_pos
|
||||||
per_pair[flat_idx] = weighted.to(per_pair.dtype)
|
per_pair[flat_idx] = weighted.to(per_pair.dtype)
|
||||||
|
|
||||||
return per_pair.view(N, top_k, H).sum(dim=1)
|
return per_pair.view(N, top_k, H).sum(dim=1)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user