diff --git a/comfy/ops.py b/comfy/ops.py index 3ee7c1216..d2e065d31 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1333,10 +1333,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self._buffers[key] = fn(buf) return self - class MoEExperts(torch.nn.Module): + class MoEExperts(CastWeightBiasOp, torch.nn.Module): """Container for E quantized expert weights, indexed via ``expert_weight(i)``. - Holds expert weights as 3D buffers/parameters. + The full bank lives on ``self.weight`` as a single (3D) tensor — either + a bf16 ``Parameter`` or a ``Parameter`` wrapping a ``QuantizedTensor`` + with leading expert dim. State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a leading expert dim — exact storage shape is layout-specific):: @@ -1362,15 +1364,34 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.register_parameter("bias", None) # Populated by _load_from_state_dict: - self.weight = None # bf16 fallback: 3D Parameter [E, out, in] + self.weight = None self.quant_format = None self.layout_type = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm self._full_precision_mm_config = False def reset_parameters(self): + # No-op so module init doesn't clobber the loaded quant weights. return None + def _apply(self, fn, recurse=True): + # Mirror Linear._apply: re-wrap each Parameter so .to()/.cuda() + # propagate through the QuantizedTensor wrapped inside self.weight. + if recurse: + for module in self.children(): + module._apply(fn) + for key, param in self._parameters.items(): + if param is None: + continue + p = fn(param) + if (not torch.is_inference_mode_enabled()) and p.is_inference(): + p = p.clone() + self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): key = f"{prefix}{param_name}" @@ -1382,7 +1403,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec manually_loaded_keys.append(key) return value - # TODO: refactor to share more code with Linear._load_from_state_dict def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): device = self.factory_kwargs["device"] layer_name = prefix.rstrip(".") @@ -1417,31 +1437,44 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(self.layout_type) + orig_shape = (self.num_experts, self.out_features, self.in_features) + # Scales keep their leading expert dim; per-expert slicing happens at access. if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): - ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - self.register_buffer("_tensor_scale", ts, persistent=False) + scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) + params = layout_cls.Params( + scale=scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=orig_shape, + ) elif self.quant_format == "mxfp8": - bs = self._load_scale_param(state_dict, prefix, "weight_scale", device, - manually_loaded_keys, dtype=torch.uint8) - if bs is None: + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) + if block_scale is None: raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}") - self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False) + params = layout_cls.Params( + scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=orig_shape, + ) elif self.quant_format == "nvfp4": - ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) - bs = self._load_scale_param(state_dict, prefix, "weight_scale", device, - manually_loaded_keys, dtype=torch.float8_e4m3fn) - if ts is None or bs is None: + tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) + if tensor_scale is None or block_scale is None: raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}") - self.register_buffer("_tensor_scale", ts, persistent=False) - self.register_buffer("_block_scale", bs, persistent=False) + params = layout_cls.Params( + scale=tensor_scale, + block_scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=orig_shape, + ) else: raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}") - self.register_buffer( - "_qdata", - weight.to(device=device, dtype=qconfig["storage_t"]), - persistent=False, + qdata = weight.to(device=device, dtype=qconfig["storage_t"]) + self.weight = torch.nn.Parameter( + QuantizedTensor(qdata, self.layout_type, params), + requires_grad=False, ) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, @@ -1451,84 +1484,76 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec missing_keys.remove(k) def expert_weight(self, i: int): - """Expert i's weight (Tensor or QuantizedTensor).""" - if self.quant_format is None: - return self.weight[i] - - qdata = self._qdata[i] - layout_cls = get_layout_class(self.layout_type) - orig_shape = (self.out_features, self.in_features) - - if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): - scale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale - params = layout_cls.Params( - scale=scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=orig_shape, - ) - elif self.quant_format == "mxfp8": - params = layout_cls.Params( - scale=self._block_scale[i], - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=orig_shape, - ) - elif self.quant_format == "nvfp4": - tscale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale - params = layout_cls.Params( - scale=tscale, - block_scale=self._block_scale[i], - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=orig_shape, - ) - else: - raise ValueError(f"Unsupported quant format: {self.quant_format}") - return QuantizedTensor(qdata, self.layout_type, params) + """Expert i's weight (Tensor or per-expert QuantizedTensor view).""" + if isinstance(self.weight, QuantizedTensor): + return self._expert_qt_from(self.weight, i) + return self.weight[i] def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor: """Linear against expert ``i``'s weight (with optional bias).""" - qw = self.expert_weight(i) - bias = None - if self.bias is not None: - bias = cast_to_input(self.bias[i], input, copy=False) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + try: + if isinstance(weight, QuantizedTensor): + qw = self._expert_qt_from(weight, i) + else: + qw = weight[i] + b = cast_to_input(bias[i], input, copy=False) if bias is not None else None - if isinstance(qw, QuantizedTensor): - use_fast = ( - not self._full_precision_mm - and qw.layout_cls.supports_fast_matmul() - and input.dim() == 2 + if isinstance(qw, QuantizedTensor): + use_fast = ( + not self._full_precision_mm + and qw.layout_cls.supports_fast_matmul() + and input.dim() == 2 + ) + if use_fast: + qin = QuantizedTensor.from_float(input, self.layout_type) + return torch.nn.functional.linear(qin, qw, b) + out = input @ qw.dequantize().t() + return out + b if b is not None else out + return torch.nn.functional.linear(input, qw, b) + finally: + uncast_bias_weight(self, weight, bias, offload_stream) + + def _expert_qt_from(self, weight: "QuantizedTensor", i: int) -> "QuantizedTensor": + """Build a per-expert QuantizedTensor by indexing into a resident bank.""" + qdata = weight._qdata[i] + params = weight._params + orig_shape = (self.out_features, self.in_features) + if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): + scale = params.scale[i] if params.scale.dim() else params.scale + per_expert_params = type(params)( + scale=scale, orig_dtype=params.orig_dtype, orig_shape=orig_shape, ) - if use_fast: - qin = QuantizedTensor.from_float(input, self.layout_type) - return torch.nn.functional.linear(qin, qw, bias) - out = input @ qw.dequantize().t() - return out + bias if bias is not None else out - - return torch.nn.functional.linear(input, qw, bias) + elif self.quant_format == "mxfp8": + per_expert_params = type(params)( + scale=params.scale[i], orig_dtype=params.orig_dtype, orig_shape=orig_shape, + ) + elif self.quant_format == "nvfp4": + scale = params.scale[i] if params.scale.dim() else params.scale + per_expert_params = type(params)( + scale=scale, block_scale=params.block_scale[i], + orig_dtype=params.orig_dtype, orig_shape=orig_shape, + ) + else: + raise ValueError(f"Unsupported quant format: {self.quant_format}") + return QuantizedTensor(qdata, weight._layout_cls, per_expert_params) def state_dict(self, *args, destination=None, prefix="", **kwargs): sd = destination if destination is not None else {} if self.bias is not None: sd[f"{prefix}bias"] = self.bias - if self.quant_format is None: - if self.weight is not None: - sd[f"{prefix}weight"] = self.weight + if self.weight is None: return sd - - sd[f"{prefix}weight"] = self._qdata - if self.quant_format == "nvfp4": - sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8) - sd[f"{prefix}weight_scale_2"] = self._tensor_scale - elif self.quant_format == "mxfp8": - sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8) - elif self.quant_format in ("float8_e4m3fn", "float8_e5m2"): - sd[f"{prefix}weight_scale"] = self._tensor_scale - - quant_conf = {"format": self.quant_format, "num_experts": self.num_experts} - if self._full_precision_mm_config: - quant_conf["full_precision_matrix_mult"] = True - sd[f"{prefix}comfy_quant"] = torch.tensor( - list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8 - ) + if isinstance(self.weight, QuantizedTensor): + sd.update(self.weight.state_dict(f"{prefix}weight")) + quant_conf = {"format": self.quant_format, "num_experts": self.num_experts} + if self._full_precision_mm_config: + quant_conf["full_precision_matrix_mult"] = True + sd[f"{prefix}comfy_quant"] = torch.tensor( + list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8 + ) + else: + sd[f"{prefix}weight"] = self.weight return sd class Embedding(manual_cast.Embedding):