mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-29 02:17:52 +08:00
Refactor MoEExperts some
This commit is contained in:
parent
c431fd555b
commit
1d48277a1c
199
comfy/ops.py
199
comfy/ops.py
@ -1333,10 +1333,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
self._buffers[key] = fn(buf)
|
self._buffers[key] = fn(buf)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
class MoEExperts(torch.nn.Module):
|
class MoEExperts(CastWeightBiasOp, torch.nn.Module):
|
||||||
"""Container for E quantized expert weights, indexed via ``expert_weight(i)``.
|
"""Container for E quantized expert weights, indexed via ``expert_weight(i)``.
|
||||||
|
|
||||||
Holds expert weights as 3D buffers/parameters.
|
The full bank lives on ``self.weight`` as a single (3D) tensor — either
|
||||||
|
a bf16 ``Parameter`` or a ``Parameter`` wrapping a ``QuantizedTensor``
|
||||||
|
with leading expert dim.
|
||||||
|
|
||||||
State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
|
State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
|
||||||
leading expert dim — exact storage shape is layout-specific)::
|
leading expert dim — exact storage shape is layout-specific)::
|
||||||
@ -1362,15 +1364,34 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
# Populated by _load_from_state_dict:
|
# Populated by _load_from_state_dict:
|
||||||
self.weight = None # bf16 fallback: 3D Parameter [E, out, in]
|
self.weight = None
|
||||||
self.quant_format = None
|
self.quant_format = None
|
||||||
self.layout_type = None
|
self.layout_type = None
|
||||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||||
self._full_precision_mm_config = False
|
self._full_precision_mm_config = False
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
|
# No-op so module init doesn't clobber the loaded quant weights.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _apply(self, fn, recurse=True):
|
||||||
|
# Mirror Linear._apply: re-wrap each Parameter so .to()/.cuda()
|
||||||
|
# propagate through the QuantizedTensor wrapped inside self.weight.
|
||||||
|
if recurse:
|
||||||
|
for module in self.children():
|
||||||
|
module._apply(fn)
|
||||||
|
for key, param in self._parameters.items():
|
||||||
|
if param is None:
|
||||||
|
continue
|
||||||
|
p = fn(param)
|
||||||
|
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||||
|
p = p.clone()
|
||||||
|
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||||
|
for key, buf in self._buffers.items():
|
||||||
|
if buf is not None:
|
||||||
|
self._buffers[key] = fn(buf)
|
||||||
|
return self
|
||||||
|
|
||||||
def _load_scale_param(self, state_dict, prefix, param_name, device,
|
def _load_scale_param(self, state_dict, prefix, param_name, device,
|
||||||
manually_loaded_keys, dtype=None):
|
manually_loaded_keys, dtype=None):
|
||||||
key = f"{prefix}{param_name}"
|
key = f"{prefix}{param_name}"
|
||||||
@ -1382,7 +1403,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
manually_loaded_keys.append(key)
|
manually_loaded_keys.append(key)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# TODO: refactor to share more code with Linear._load_from_state_dict
|
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
device = self.factory_kwargs["device"]
|
device = self.factory_kwargs["device"]
|
||||||
layer_name = prefix.rstrip(".")
|
layer_name = prefix.rstrip(".")
|
||||||
@ -1417,31 +1437,44 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
qconfig = QUANT_ALGOS[self.quant_format]
|
qconfig = QUANT_ALGOS[self.quant_format]
|
||||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||||
|
layout_cls = get_layout_class(self.layout_type)
|
||||||
|
orig_shape = (self.num_experts, self.out_features, self.in_features)
|
||||||
|
|
||||||
|
# Scales keep their leading expert dim; per-expert slicing happens at access.
|
||||||
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||||
ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||||
self.register_buffer("_tensor_scale", ts, persistent=False)
|
params = layout_cls.Params(
|
||||||
|
scale=scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=orig_shape,
|
||||||
|
)
|
||||||
elif self.quant_format == "mxfp8":
|
elif self.quant_format == "mxfp8":
|
||||||
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||||
manually_loaded_keys, dtype=torch.uint8)
|
if block_scale is None:
|
||||||
if bs is None:
|
|
||||||
raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
|
raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
|
||||||
self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False)
|
params = layout_cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=orig_shape,
|
||||||
|
)
|
||||||
elif self.quant_format == "nvfp4":
|
elif self.quant_format == "nvfp4":
|
||||||
ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||||
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||||
manually_loaded_keys, dtype=torch.float8_e4m3fn)
|
if tensor_scale is None or block_scale is None:
|
||||||
if ts is None or bs is None:
|
|
||||||
raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
|
raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
|
||||||
self.register_buffer("_tensor_scale", ts, persistent=False)
|
params = layout_cls.Params(
|
||||||
self.register_buffer("_block_scale", bs, persistent=False)
|
scale=tensor_scale,
|
||||||
|
block_scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=orig_shape,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")
|
raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")
|
||||||
|
|
||||||
self.register_buffer(
|
qdata = weight.to(device=device, dtype=qconfig["storage_t"])
|
||||||
"_qdata",
|
self.weight = torch.nn.Parameter(
|
||||||
weight.to(device=device, dtype=qconfig["storage_t"]),
|
QuantizedTensor(qdata, self.layout_type, params),
|
||||||
persistent=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
@ -1451,84 +1484,76 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
missing_keys.remove(k)
|
missing_keys.remove(k)
|
||||||
|
|
||||||
def expert_weight(self, i: int):
|
def expert_weight(self, i: int):
|
||||||
"""Expert i's weight (Tensor or QuantizedTensor)."""
|
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
|
||||||
if self.quant_format is None:
|
if isinstance(self.weight, QuantizedTensor):
|
||||||
return self.weight[i]
|
return self._expert_qt_from(self.weight, i)
|
||||||
|
return self.weight[i]
|
||||||
qdata = self._qdata[i]
|
|
||||||
layout_cls = get_layout_class(self.layout_type)
|
|
||||||
orig_shape = (self.out_features, self.in_features)
|
|
||||||
|
|
||||||
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
|
||||||
scale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=scale,
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
elif self.quant_format == "mxfp8":
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=self._block_scale[i],
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
elif self.quant_format == "nvfp4":
|
|
||||||
tscale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale
|
|
||||||
params = layout_cls.Params(
|
|
||||||
scale=tscale,
|
|
||||||
block_scale=self._block_scale[i],
|
|
||||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
|
||||||
orig_shape=orig_shape,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported quant format: {self.quant_format}")
|
|
||||||
return QuantizedTensor(qdata, self.layout_type, params)
|
|
||||||
|
|
||||||
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
||||||
"""Linear against expert ``i``'s weight (with optional bias)."""
|
"""Linear against expert ``i``'s weight (with optional bias)."""
|
||||||
qw = self.expert_weight(i)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
bias = None
|
try:
|
||||||
if self.bias is not None:
|
if isinstance(weight, QuantizedTensor):
|
||||||
bias = cast_to_input(self.bias[i], input, copy=False)
|
qw = self._expert_qt_from(weight, i)
|
||||||
|
else:
|
||||||
|
qw = weight[i]
|
||||||
|
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
|
||||||
|
|
||||||
if isinstance(qw, QuantizedTensor):
|
if isinstance(qw, QuantizedTensor):
|
||||||
use_fast = (
|
use_fast = (
|
||||||
not self._full_precision_mm
|
not self._full_precision_mm
|
||||||
and qw.layout_cls.supports_fast_matmul()
|
and qw.layout_cls.supports_fast_matmul()
|
||||||
and input.dim() == 2
|
and input.dim() == 2
|
||||||
|
)
|
||||||
|
if use_fast:
|
||||||
|
qin = QuantizedTensor.from_float(input, self.layout_type)
|
||||||
|
return torch.nn.functional.linear(qin, qw, b)
|
||||||
|
out = input @ qw.dequantize().t()
|
||||||
|
return out + b if b is not None else out
|
||||||
|
return torch.nn.functional.linear(input, qw, b)
|
||||||
|
finally:
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
|
||||||
|
def _expert_qt_from(self, weight: "QuantizedTensor", i: int) -> "QuantizedTensor":
|
||||||
|
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
|
||||||
|
qdata = weight._qdata[i]
|
||||||
|
params = weight._params
|
||||||
|
orig_shape = (self.out_features, self.in_features)
|
||||||
|
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||||
|
scale = params.scale[i] if params.scale.dim() else params.scale
|
||||||
|
per_expert_params = type(params)(
|
||||||
|
scale=scale, orig_dtype=params.orig_dtype, orig_shape=orig_shape,
|
||||||
)
|
)
|
||||||
if use_fast:
|
elif self.quant_format == "mxfp8":
|
||||||
qin = QuantizedTensor.from_float(input, self.layout_type)
|
per_expert_params = type(params)(
|
||||||
return torch.nn.functional.linear(qin, qw, bias)
|
scale=params.scale[i], orig_dtype=params.orig_dtype, orig_shape=orig_shape,
|
||||||
out = input @ qw.dequantize().t()
|
)
|
||||||
return out + bias if bias is not None else out
|
elif self.quant_format == "nvfp4":
|
||||||
|
scale = params.scale[i] if params.scale.dim() else params.scale
|
||||||
return torch.nn.functional.linear(input, qw, bias)
|
per_expert_params = type(params)(
|
||||||
|
scale=scale, block_scale=params.block_scale[i],
|
||||||
|
orig_dtype=params.orig_dtype, orig_shape=orig_shape,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported quant format: {self.quant_format}")
|
||||||
|
return QuantizedTensor(qdata, weight._layout_cls, per_expert_params)
|
||||||
|
|
||||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||||
sd = destination if destination is not None else {}
|
sd = destination if destination is not None else {}
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
sd[f"{prefix}bias"] = self.bias
|
sd[f"{prefix}bias"] = self.bias
|
||||||
if self.quant_format is None:
|
if self.weight is None:
|
||||||
if self.weight is not None:
|
|
||||||
sd[f"{prefix}weight"] = self.weight
|
|
||||||
return sd
|
return sd
|
||||||
|
if isinstance(self.weight, QuantizedTensor):
|
||||||
sd[f"{prefix}weight"] = self._qdata
|
sd.update(self.weight.state_dict(f"{prefix}weight"))
|
||||||
if self.quant_format == "nvfp4":
|
quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
|
||||||
sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
|
if self._full_precision_mm_config:
|
||||||
sd[f"{prefix}weight_scale_2"] = self._tensor_scale
|
quant_conf["full_precision_matrix_mult"] = True
|
||||||
elif self.quant_format == "mxfp8":
|
sd[f"{prefix}comfy_quant"] = torch.tensor(
|
||||||
sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
|
list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
|
||||||
elif self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
)
|
||||||
sd[f"{prefix}weight_scale"] = self._tensor_scale
|
else:
|
||||||
|
sd[f"{prefix}weight"] = self.weight
|
||||||
quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
|
|
||||||
if self._full_precision_mm_config:
|
|
||||||
quant_conf["full_precision_matrix_mult"] = True
|
|
||||||
sd[f"{prefix}comfy_quant"] = torch.tensor(
|
|
||||||
list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
|
|
||||||
)
|
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
class Embedding(manual_cast.Embedding):
|
class Embedding(manual_cast.Embedding):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user