Refactor MoEExperts some

This commit is contained in:
kijai 2026-05-24 03:48:47 +03:00
parent c431fd555b
commit 1d48277a1c

View File

@ -1333,10 +1333,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf)
return self
class MoEExperts(torch.nn.Module):
class MoEExperts(CastWeightBiasOp, torch.nn.Module):
"""Container for E quantized expert weights, indexed via ``expert_weight(i)``.
Holds expert weights as 3D buffers/parameters.
The full bank lives on ``self.weight`` as a single (3D) tensor: either
a bf16 ``Parameter`` or a ``Parameter`` wrapping a ``QuantizedTensor``
with a leading expert dim.
State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
leading expert dim; exact storage shape is layout-specific)::
@ -1362,15 +1364,34 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.register_parameter("bias", None)
# Populated by _load_from_state_dict:
self.weight = None # bf16 fallback: 3D Parameter [E, out, in]
self.weight = None
self.quant_format = None
self.layout_type = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False
def reset_parameters(self):
    """Do nothing, so module init doesn't clobber the loaded quant weights."""
    return None
def _apply(self, fn, recurse=True):
# Mirror Linear._apply: re-wrap each Parameter so .to()/.cuda()
# propagate through the QuantizedTensor wrapped inside self.weight.
if recurse:
for module in self.children():
module._apply(fn)
for key, param in self._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
def _load_scale_param(self, state_dict, prefix, param_name, device,
manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
@ -1382,7 +1403,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys.append(key)
return value
# TODO: refactor to share more code with Linear._load_from_state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip(".")
@ -1417,31 +1437,44 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
orig_shape = (self.num_experts, self.out_features, self.in_features)
# Scales keep their leading expert dim; per-expert slicing happens at access.
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
self.register_buffer("_tensor_scale", ts, persistent=False)
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
elif self.quant_format == "mxfp8":
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
manually_loaded_keys, dtype=torch.uint8)
if bs is None:
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
elif self.quant_format == "nvfp4":
ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
manually_loaded_keys, dtype=torch.float8_e4m3fn)
if ts is None or bs is None:
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
self.register_buffer("_tensor_scale", ts, persistent=False)
self.register_buffer("_block_scale", bs, persistent=False)
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
else:
raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")
self.register_buffer(
"_qdata",
weight.to(device=device, dtype=qconfig["storage_t"]),
persistent=False,
qdata = weight.to(device=device, dtype=qconfig["storage_t"])
self.weight = torch.nn.Parameter(
QuantizedTensor(qdata, self.layout_type, params),
requires_grad=False,
)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
@ -1451,84 +1484,76 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(k)
def expert_weight(self, i: int):
    """Return expert ``i``'s weight (Tensor or per-expert QuantizedTensor view).

    When ``self.weight`` is a ``QuantizedTensor`` the per-expert tensor is
    built by ``self._expert_qt_from``; otherwise the result is simply
    ``self.weight[i]``.

    Args:
        i: Index of the expert along the leading dim of ``self.weight``.
    """
    # Only the ``self.weight`` path is kept: the interleaved earlier body read
    # ``self._qdata`` / ``self._tensor_scale`` / ``self._block_scale`` and
    # returned before this code could run.
    if isinstance(self.weight, QuantizedTensor):
        return self._expert_qt_from(self.weight, i)
    return self.weight[i]
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
    """Linear against expert ``i``'s weight (with optional bias).

    The whole weight/bias pair is obtained through ``cast_bias_weight`` and is
    always handed back to ``uncast_bias_weight`` when the call finishes,
    including when an exception is raised.

    Args:
        input: Activations to multiply with the expert's weight.
        i: Index of the expert to use.

    Returns:
        The result of the linear op for expert ``i``.
    """
    weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
    try:
        # Select expert i out of the cast bank.
        if isinstance(weight, QuantizedTensor):
            qw = self._expert_qt_from(weight, i)
        else:
            qw = weight[i]
        b = cast_to_input(bias[i], input, copy=False) if bias is not None else None

        if isinstance(qw, QuantizedTensor):
            # The quantized matmul is only attempted for 2-D input, when the
            # layout supports it and full-precision matmul was not requested.
            use_fast = (
                not self._full_precision_mm
                and qw.layout_cls.supports_fast_matmul()
                and input.dim() == 2
            )
            if use_fast:
                qin = QuantizedTensor.from_float(input, self.layout_type)
                return torch.nn.functional.linear(qin, qw, b)
            # Otherwise dequantize the expert weight and do a plain matmul.
            out = input @ qw.dequantize().t()
            return out + b if b is not None else out
        return torch.nn.functional.linear(input, qw, b)
    finally:
        uncast_bias_weight(self, weight, bias, offload_stream)
def _expert_qt_from(self, weight: "QuantizedTensor", i: int) -> "QuantizedTensor":
    """Build a per-expert QuantizedTensor by indexing into a resident bank.

    Args:
        weight: QuantizedTensor holding all experts; its ``_qdata`` and the
            scales in its ``_params`` are indexed with ``i``.
        i: Index of the expert to extract.

    Returns:
        A ``QuantizedTensor`` for expert ``i`` whose ``orig_shape`` is
        ``(self.out_features, self.in_features)``.

    Raises:
        ValueError: If ``self.quant_format`` is not one of the handled formats.
    """
    qdata = weight._qdata[i]
    params = weight._params
    orig_shape = (self.out_features, self.in_features)

    if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
        # A 0-dim scale is shared by all experts; otherwise take expert i's.
        scale = params.scale[i] if params.scale.dim() else params.scale
        per_expert_params = type(params)(
            scale=scale, orig_dtype=params.orig_dtype, orig_shape=orig_shape,
        )
    elif self.quant_format == "mxfp8":
        per_expert_params = type(params)(
            scale=params.scale[i], orig_dtype=params.orig_dtype, orig_shape=orig_shape,
        )
    elif self.quant_format == "nvfp4":
        # Same 0-dim handling for the tensor scale; block scales are always indexed.
        scale = params.scale[i] if params.scale.dim() else params.scale
        per_expert_params = type(params)(
            scale=scale, block_scale=params.block_scale[i],
            orig_dtype=params.orig_dtype, orig_shape=orig_shape,
        )
    else:
        raise ValueError(f"Unsupported quant format: {self.quant_format}")

    return QuantizedTensor(qdata, weight._layout_cls, per_expert_params)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
    """Collect this module's bias, weight and quantization metadata.

    Args:
        destination: Optional dict to fill; a new dict is created when omitted.
        prefix: String prepended to every key written.
        *args, **kwargs: Accepted but not used by this implementation.

    Returns:
        The dict that was filled. It holds ``{prefix}bias`` when a bias is
        set. For a ``QuantizedTensor`` weight it also holds the entries from
        ``self.weight.state_dict(f"{prefix}weight")`` plus ``{prefix}comfy_quant``
        (the JSON quant config encoded as a uint8 tensor); for any other
        non-``None`` weight it holds ``{prefix}weight``.
    """
    sd = destination if destination is not None else {}
    if self.bias is not None:
        sd[f"{prefix}bias"] = self.bias
    if self.weight is None:
        return sd

    if isinstance(self.weight, QuantizedTensor):
        sd.update(self.weight.state_dict(f"{prefix}weight"))
        quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
        if self._full_precision_mm_config:
            quant_conf["full_precision_matrix_mult"] = True
        # The config is stored as the UTF-8 bytes of its JSON form.
        sd[f"{prefix}comfy_quant"] = torch.tensor(
            list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
        )
    else:
        sd[f"{prefix}weight"] = self.weight
    return sd
class Embedding(manual_cast.Embedding):