Refactor MoEExperts some

This commit is contained in:
kijai 2026-05-24 03:48:47 +03:00
parent c431fd555b
commit 1d48277a1c

View File

@ -1333,10 +1333,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
return self return self
class MoEExperts(torch.nn.Module): class MoEExperts(CastWeightBiasOp, torch.nn.Module):
"""Container for E quantized expert weights, indexed via ``expert_weight(i)``. """Container for E quantized expert weights, indexed via ``expert_weight(i)``.
Holds expert weights as 3D buffers/parameters. The full bank lives on ``self.weight`` as a single (3D) tensor either
a bf16 ``Parameter`` or a ``Parameter`` wrapping a ``QuantizedTensor``
with leading expert dim.
State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
leading expert dim exact storage shape is layout-specific):: leading expert dim exact storage shape is layout-specific)::
@ -1362,15 +1364,34 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.register_parameter("bias", None) self.register_parameter("bias", None)
# Populated by _load_from_state_dict: # Populated by _load_from_state_dict:
self.weight = None # bf16 fallback: 3D Parameter [E, out, in] self.weight = None
self.quant_format = None self.quant_format = None
self.layout_type = None self.layout_type = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False self._full_precision_mm_config = False
def reset_parameters(self): def reset_parameters(self):
# No-op so module init doesn't clobber the loaded quant weights.
return None return None
def _apply(self, fn, recurse=True):
# Mirror Linear._apply: re-wrap each Parameter so .to()/.cuda()
# propagate through the QuantizedTensor wrapped inside self.weight.
if recurse:
for module in self.children():
module._apply(fn)
for key, param in self._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
def _load_scale_param(self, state_dict, prefix, param_name, device, def _load_scale_param(self, state_dict, prefix, param_name, device,
manually_loaded_keys, dtype=None): manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}" key = f"{prefix}{param_name}"
@ -1382,7 +1403,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys.append(key) manually_loaded_keys.append(key)
return value return value
# TODO: refactor to share more code with Linear._load_from_state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"] device = self.factory_kwargs["device"]
layer_name = prefix.rstrip(".") layer_name = prefix.rstrip(".")
@ -1417,31 +1437,44 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
qconfig = QUANT_ALGOS[self.quant_format] qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"] self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
orig_shape = (self.num_experts, self.out_features, self.in_features)
# Scales keep their leading expert dim; per-expert slicing happens at access.
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
self.register_buffer("_tensor_scale", ts, persistent=False) params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
elif self.quant_format == "mxfp8": elif self.quant_format == "mxfp8":
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device, block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
manually_loaded_keys, dtype=torch.uint8) if block_scale is None:
if bs is None:
raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}") raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False) params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
elif self.quant_format == "nvfp4": elif self.quant_format == "nvfp4":
ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device, block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
manually_loaded_keys, dtype=torch.float8_e4m3fn) if tensor_scale is None or block_scale is None:
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}") raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
self.register_buffer("_tensor_scale", ts, persistent=False) params = layout_cls.Params(
self.register_buffer("_block_scale", bs, persistent=False) scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
else: else:
raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}") raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")
self.register_buffer( qdata = weight.to(device=device, dtype=qconfig["storage_t"])
"_qdata", self.weight = torch.nn.Parameter(
weight.to(device=device, dtype=qconfig["storage_t"]), QuantizedTensor(qdata, self.layout_type, params),
persistent=False, requires_grad=False,
) )
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
@ -1451,84 +1484,76 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(k) missing_keys.remove(k)
def expert_weight(self, i: int): def expert_weight(self, i: int):
"""Expert i's weight (Tensor or QuantizedTensor).""" """Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
if self.quant_format is None: if isinstance(self.weight, QuantizedTensor):
return self.weight[i] return self._expert_qt_from(self.weight, i)
return self.weight[i]
qdata = self._qdata[i]
layout_cls = get_layout_class(self.layout_type)
orig_shape = (self.out_features, self.in_features)
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
scale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
elif self.quant_format == "mxfp8":
params = layout_cls.Params(
scale=self._block_scale[i],
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
elif self.quant_format == "nvfp4":
tscale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale
params = layout_cls.Params(
scale=tscale,
block_scale=self._block_scale[i],
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=orig_shape,
)
else:
raise ValueError(f"Unsupported quant format: {self.quant_format}")
return QuantizedTensor(qdata, self.layout_type, params)
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor: def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
"""Linear against expert ``i``'s weight (with optional bias).""" """Linear against expert ``i``'s weight (with optional bias)."""
qw = self.expert_weight(i) weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
bias = None try:
if self.bias is not None: if isinstance(weight, QuantizedTensor):
bias = cast_to_input(self.bias[i], input, copy=False) qw = self._expert_qt_from(weight, i)
else:
qw = weight[i]
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
if isinstance(qw, QuantizedTensor): if isinstance(qw, QuantizedTensor):
use_fast = ( use_fast = (
not self._full_precision_mm not self._full_precision_mm
and qw.layout_cls.supports_fast_matmul() and qw.layout_cls.supports_fast_matmul()
and input.dim() == 2 and input.dim() == 2
)
if use_fast:
qin = QuantizedTensor.from_float(input, self.layout_type)
return torch.nn.functional.linear(qin, qw, b)
out = input @ qw.dequantize().t()
return out + b if b is not None else out
return torch.nn.functional.linear(input, qw, b)
finally:
uncast_bias_weight(self, weight, bias, offload_stream)
def _expert_qt_from(self, weight: "QuantizedTensor", i: int) -> "QuantizedTensor":
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
qdata = weight._qdata[i]
params = weight._params
orig_shape = (self.out_features, self.in_features)
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
scale = params.scale[i] if params.scale.dim() else params.scale
per_expert_params = type(params)(
scale=scale, orig_dtype=params.orig_dtype, orig_shape=orig_shape,
) )
if use_fast: elif self.quant_format == "mxfp8":
qin = QuantizedTensor.from_float(input, self.layout_type) per_expert_params = type(params)(
return torch.nn.functional.linear(qin, qw, bias) scale=params.scale[i], orig_dtype=params.orig_dtype, orig_shape=orig_shape,
out = input @ qw.dequantize().t() )
return out + bias if bias is not None else out elif self.quant_format == "nvfp4":
scale = params.scale[i] if params.scale.dim() else params.scale
return torch.nn.functional.linear(input, qw, bias) per_expert_params = type(params)(
scale=scale, block_scale=params.block_scale[i],
orig_dtype=params.orig_dtype, orig_shape=orig_shape,
)
else:
raise ValueError(f"Unsupported quant format: {self.quant_format}")
return QuantizedTensor(qdata, weight._layout_cls, per_expert_params)
def state_dict(self, *args, destination=None, prefix="", **kwargs): def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = destination if destination is not None else {} sd = destination if destination is not None else {}
if self.bias is not None: if self.bias is not None:
sd[f"{prefix}bias"] = self.bias sd[f"{prefix}bias"] = self.bias
if self.quant_format is None: if self.weight is None:
if self.weight is not None:
sd[f"{prefix}weight"] = self.weight
return sd return sd
if isinstance(self.weight, QuantizedTensor):
sd[f"{prefix}weight"] = self._qdata sd.update(self.weight.state_dict(f"{prefix}weight"))
if self.quant_format == "nvfp4": quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8) if self._full_precision_mm_config:
sd[f"{prefix}weight_scale_2"] = self._tensor_scale quant_conf["full_precision_matrix_mult"] = True
elif self.quant_format == "mxfp8": sd[f"{prefix}comfy_quant"] = torch.tensor(
sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8) list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
elif self.quant_format in ("float8_e4m3fn", "float8_e5m2"): )
sd[f"{prefix}weight_scale"] = self._tensor_scale else:
sd[f"{prefix}weight"] = self.weight
quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd[f"{prefix}comfy_quant"] = torch.tensor(
list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
)
return sd return sd
class Embedding(manual_cast.Embedding): class Embedding(manual_cast.Embedding):