diff --git a/comfy/ops.py b/comfy/ops.py index d2e065d31..56445be8d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -18,6 +18,7 @@ import torch import logging +import contextlib import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float @@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +# Quantized-weight module helpers + +def _quantized_apply(module, fn, recurse=True): + """Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights.""" + if recurse: + for child in module.children(): + child._apply(fn) + for key, param in module._parameters.items(): + if param is None: + continue + p = fn(param) + if (not torch.is_inference_mode_enabled()) and p.is_inference(): + p = p.clone() + module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) + for key, buf in module._buffers.items(): + if buf is not None: + module._buffers[key] = fn(buf) + return module + + +def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs, load_extra_params=False): + """Shared _load_from_state_dict body for quantized-weight modules. + + Pops weight (+ scales, +/- extras), populates module.weight as a Parameter + or Parameter-wrapped QuantizedTensor, then calls super_load and strips + consumed keys from missing_keys. Reads compute_dtype from factory_kwargs + and disabled formats from module._disabled_formats. + """ + device = module.factory_kwargs["device"] + compute_dtype = module.factory_kwargs["dtype"] + disabled_formats = module._disabled_formats + layer_name = prefix.rstrip('.') + + weight = state_dict.pop(f"{prefix}weight", None) + if weight is None: + logging.warning(f"Missing weight for layer {layer_name}") + module.weight = None + return + manually_loaded_keys = [f"{prefix}weight"] + + def pop_scale(name, dtype=None): + key = f"{prefix}{name}" + v = state_dict.pop(key, None) + if v is not None: + v = v.to(device=device) + if dtype is not None: + v = v.view(dtype=dtype) + manually_loaded_keys.append(key) + return v + + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: + module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False) + else: + module.quant_format = layer_conf.get("format", None) + module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) + if not module._full_precision_mm: + module._full_precision_mm = module._full_precision_mm_config + if module.quant_format in disabled_formats: + module._full_precision_mm = True + if module.quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + qconfig = QUANT_ALGOS[module.quant_format] + module.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(module.layout_type) + + # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8. + if module.quant_format in ("float8_e4m3fn", "float8_e5m2"): + scales = {"scale": pop_scale("weight_scale")} + elif module.quant_format == "mxfp8": + bs = pop_scale("weight_scale", torch.float8_e8m0fnu) + if bs is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + scales = {"scale": bs} + elif module.quant_format == "nvfp4": + ts = pop_scale("weight_scale_2") + bs = pop_scale("weight_scale", torch.float8_e4m3fn) + if ts is None or bs is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + scales = {"scale": ts, "block_scale": bs} + else: + raise ValueError(f"Unsupported quantization format: {module.quant_format}") + + params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape) + module.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params), + requires_grad=False, + ) + + if load_extra_params: + for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + +def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()): + """Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON; + extra_quant_params names attributes written as additional top-level keys.""" + if not hasattr(module, 'weight'): + logging.warning(f"Warning: state dict on uninitialized op {prefix}") + return sd + bias = getattr(module, 'bias', None) + if bias is not None: + sd[f"{prefix}bias"] = bias + if module.weight is None: + return sd + if isinstance(module.weight, QuantizedTensor): + sd.update(module.weight.state_dict(f"{prefix}weight")) + quant_conf = {"format": module.quant_format} + if getattr(module, '_full_precision_mm_config', False): + quant_conf["full_precision_matrix_mult"] = True + if extra_quant_conf: + quant_conf.update(extra_quant_conf) + sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8) + for name in extra_quant_params: + value = getattr(module, name, None) + if value is not None: + sd[f"{prefix}{name}"] = value + else: + sd[f"{prefix}weight"] = module.weight + return sd + def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): @@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: + _disabled_formats = disabled + + def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): super().__init__() self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features + self._orig_shape = (out_features, in_features) if bias: self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) else: @@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def reset_parameters(self): return None - def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): - key = f"{prefix}{param_name}" - value = state_dict.pop(key, None) - if value is not None: - value = value.to(device=device) - if dtype is not None: - value = value.view(dtype=dtype) - manually_loaded_keys.append(key) - return value - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): - - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - logging.warning(f"Missing weight for layer {layer_name}") - self.weight = None - return - - manually_loaded_keys = [weight_key] - - layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) - if layer_conf is not None: - layer_conf = json.loads(layer_conf.numpy().tobytes()) - - if layer_conf is None: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - self.quant_format = layer_conf.get("format", None) - self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) - if not self._full_precision_mm: - self._full_precision_mm = self._full_precision_mm_config - - if self.quant_format in MixedPrecisionOps._disabled: - self._full_precision_mm = True - - if self.quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") - - qconfig = QUANT_ALGOS[self.quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] - layout_cls = get_layout_class(self.layout_type) - - # Load format-specific parameters - if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: - # FP8: single tensor scale - scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - - params = layout_cls.Params( - scale=scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "mxfp8": - # MXFP8: E8M0 block scales stored as uint8 in safetensors - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.uint8) - - if block_scale is None: - raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") - - block_scale = block_scale.view(torch.float8_e8m0fnu) - - params = layout_cls.Params( - scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "nvfp4": - # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) - tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.float8_e4m3fn) - - if tensor_scale is None or block_scale is None: - raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") - - params = layout_cls.Params( - scale=tensor_scale, - block_scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - else: - raise ValueError(f"Unsupported quantization format: {self.quant_format}") - - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), - requires_grad=False - ) - - for param_name in qconfig["parameters"]: - if param_name in {"weight_scale", "weight_scale_2"}: - continue # Already handled above - - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) - - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight'): - logging.warning("Warning: state dict on uninitialized op {}".format(prefix)) - return sd - - if self.bias is not None: - sd["{}bias".format(prefix)] = self.bias - - if self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - if self._full_precision_mm_config: - quant_conf["full_precision_matrix_mult"] = True - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - - input_scale = getattr(self, 'input_scale', None) - if input_scale is not None: - sd["{}input_scale".format(prefix)] = input_scale - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",)) def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -1317,46 +1312,34 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight, requires_grad=False) def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working - if recurse: - for module in self.children(): - module._apply(fn) + return _quantized_apply(self, fn, recurse) - for key, param in self._parameters.items(): - if param is None: - continue - p = fn(param) - if (not torch.is_inference_mode_enabled()) and p.is_inference(): - p = p.clone() - self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) - for key, buf in self._buffers.items(): - if buf is not None: - self._buffers[key] = fn(buf) - return self + class MoEExperts(torch.nn.Module, CastWeightBiasOp): + """Container for E quantized expert weights, indexed via expert_weight(i). - class MoEExperts(CastWeightBiasOp, torch.nn.Module): - """Container for E quantized expert weights, indexed via ``expert_weight(i)``. - - The full bank lives on ``self.weight`` as a single (3D) tensor — either - a bf16 ``Parameter`` or a ``Parameter`` wrapping a ``QuantizedTensor`` + The bank lives on self.weight as a single 3D tensor — either a + compute_dtype Parameter or a Parameter wrapping a QuantizedTensor with leading expert dim. - State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a - leading expert dim — exact storage shape is layout-specific):: - + State-dict layout matches mixed_precision_ops.Linear with a leading + expert dim: {prefix}.weight quant data (storage_t), leading dim = E {prefix}.weight_scale block / per-tensor scale {prefix}.weight_scale_2 [E] or scalar NVFP4 only - {prefix}.bias [E, out_features] optional, bf16 + {prefix}.bias [E, out_features] optional, compute_dtype {prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}} - Without ``comfy_quant`` the weight loads as a plain bf16 3D Parameter ``[E, out, in]``. + Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in]. """ + _disabled_formats = disabled + def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): super().__init__() self.num_experts = num_experts self.in_features = in_features self.out_features = out_features + self._orig_shape = (num_experts, out_features, in_features) self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} if bias: self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs)) @@ -1369,119 +1352,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.layout_type = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm self._full_precision_mm_config = False + self._resident_bank = None def reset_parameters(self): - # No-op so module init doesn't clobber the loaded quant weights. return None def _apply(self, fn, recurse=True): - # Mirror Linear._apply: re-wrap each Parameter so .to()/.cuda() - # propagate through the QuantizedTensor wrapped inside self.weight. - if recurse: - for module in self.children(): - module._apply(fn) - for key, param in self._parameters.items(): - if param is None: - continue - p = fn(param) - if (not torch.is_inference_mode_enabled()) and p.is_inference(): - p = p.clone() - self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) - for key, buf in self._buffers.items(): - if buf is not None: - self._buffers[key] = fn(buf) - return self + return _quantized_apply(self, fn, recurse) - def _load_scale_param(self, state_dict, prefix, param_name, device, - manually_loaded_keys, dtype=None): - key = f"{prefix}{param_name}" - value = state_dict.pop(key, None) - if value is not None: - value = value.to(device=device) - if dtype is not None: - value = value.view(dtype=dtype) - manually_loaded_keys.append(key) - return value - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip(".") - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - logging.warning(f"Missing weight for MoEExperts layer {layer_name}") - return - manually_loaded_keys = [weight_key] - - layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) - if layer_conf is not None: - layer_conf = json.loads(layer_conf.numpy().tobytes()) - manually_loaded_keys.append(f"{prefix}comfy_quant") - - if layer_conf is None: - self.weight = torch.nn.Parameter( - weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), - requires_grad=False, - ) - else: - self.quant_format = layer_conf.get("format") - self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) - if not self._full_precision_mm: - self._full_precision_mm = self._full_precision_mm_config - - if self.quant_format in MixedPrecisionOps._disabled: - self._full_precision_mm = True - - if self.quant_format is None: - raise ValueError(f"Unknown quant format for MoEExperts layer {layer_name}") - - qconfig = QUANT_ALGOS[self.quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] - layout_cls = get_layout_class(self.layout_type) - orig_shape = (self.num_experts, self.out_features, self.in_features) - - # Scales keep their leading expert dim; per-expert slicing happens at access. - if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): - scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - params = layout_cls.Params( - scale=scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=orig_shape, - ) - elif self.quant_format == "mxfp8": - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - if block_scale is None: - raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}") - params = layout_cls.Params( - scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=orig_shape, - ) - elif self.quant_format == "nvfp4": - tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - if tensor_scale is None or block_scale is None: - raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}") - params = layout_cls.Params( - scale=tensor_scale, - block_scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=orig_shape, - ) - else: - raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}") - - qdata = weight.to(device=device, dtype=qconfig["storage_t"]) - self.weight = torch.nn.Parameter( - QuantizedTensor(qdata, self.layout_type, params), - requires_grad=False, - ) - - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) - for k in manually_loaded_keys: - if k in missing_keys: - missing_keys.remove(k) + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False) def expert_weight(self, i: int): """Expert i's weight (Tensor or per-expert QuantizedTensor view).""" @@ -1489,76 +1369,69 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec return self._expert_qt_from(self.weight, i) return self.weight[i] + @contextlib.contextmanager + def bank_resident(self, input): + """Cast the whole bank once; expert_linear inside reuses the cast. + Not re-entrant — do not nest calls on the same instance. + """ + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + self._resident_bank = (weight, bias) + try: + yield self + finally: + self._resident_bank = None + uncast_bias_weight(self, weight, bias, offload_stream) + def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor: - """Linear against expert ``i``'s weight (with optional bias).""" + """Linear against expert i's weight (with optional bias).""" + resident = getattr(self, "_resident_bank", None) + if resident is not None: + weight, bias = resident + return self._expert_linear_impl(input, weight, bias, i) weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) try: - if isinstance(weight, QuantizedTensor): - qw = self._expert_qt_from(weight, i) - else: - qw = weight[i] - b = cast_to_input(bias[i], input, copy=False) if bias is not None else None - - if isinstance(qw, QuantizedTensor): - use_fast = ( - not self._full_precision_mm - and qw.layout_cls.supports_fast_matmul() - and input.dim() == 2 - ) - if use_fast: - qin = QuantizedTensor.from_float(input, self.layout_type) - return torch.nn.functional.linear(qin, qw, b) - out = input @ qw.dequantize().t() - return out + b if b is not None else out - return torch.nn.functional.linear(input, qw, b) + return self._expert_linear_impl(input, weight, bias, i) finally: uncast_bias_weight(self, weight, bias, offload_stream) - def _expert_qt_from(self, weight: "QuantizedTensor", i: int) -> "QuantizedTensor": - """Build a per-expert QuantizedTensor by indexing into a resident bank.""" - qdata = weight._qdata[i] - params = weight._params - orig_shape = (self.out_features, self.in_features) - if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): - scale = params.scale[i] if params.scale.dim() else params.scale - per_expert_params = type(params)( - scale=scale, orig_dtype=params.orig_dtype, orig_shape=orig_shape, - ) - elif self.quant_format == "mxfp8": - per_expert_params = type(params)( - scale=params.scale[i], orig_dtype=params.orig_dtype, orig_shape=orig_shape, - ) - elif self.quant_format == "nvfp4": - scale = params.scale[i] if params.scale.dim() else params.scale - per_expert_params = type(params)( - scale=scale, block_scale=params.block_scale[i], - orig_dtype=params.orig_dtype, orig_shape=orig_shape, - ) + def _expert_linear_impl(self, input, weight, bias, i): + if isinstance(weight, QuantizedTensor): + qw = self._expert_qt_from(weight, i) else: - raise ValueError(f"Unsupported quant format: {self.quant_format}") - return QuantizedTensor(qdata, weight._layout_cls, per_expert_params) + qw = weight[i] + b = cast_to_input(bias[i], input, copy=False) if bias is not None else None + + if isinstance(qw, QuantizedTensor): + use_fast = ( + not self._full_precision_mm + and qw.layout_cls.supports_fast_matmul() + and input.dim() == 2 + ) + if use_fast: + qin = QuantizedTensor.from_float(input, self.layout_type) + return torch.nn.functional.linear(qin, qw, b) + out = input @ qw.dequantize().t() + return out + b if b is not None else out + return torch.nn.functional.linear(input, qw, b) + + def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor: + """Build a per-expert QuantizedTensor by indexing into a resident bank.""" + params = weight._params + kwargs = { + "scale": params.scale[i] if params.scale.dim() else params.scale, + "orig_dtype": params.orig_dtype, + "orig_shape": (self.out_features, self.in_features), + } + if hasattr(params, "block_scale"): # NVFP4 + kwargs["block_scale"] = params.block_scale[i] + return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs)) def state_dict(self, *args, destination=None, prefix="", **kwargs): sd = destination if destination is not None else {} - if self.bias is not None: - sd[f"{prefix}bias"] = self.bias - if self.weight is None: - return sd - if isinstance(self.weight, QuantizedTensor): - sd.update(self.weight.state_dict(f"{prefix}weight")) - quant_conf = {"format": self.quant_format, "num_experts": self.num_experts} - if self._full_precision_mm_config: - quant_conf["full_precision_matrix_mult"] = True - sd[f"{prefix}comfy_quant"] = torch.tensor( - list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8 - ) - else: - sd[f"{prefix}weight"] = self.weight - return sd + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts}) class Embedding(manual_cast.Embedding): - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): weight_key = f"{prefix}weight" layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) if layer_conf is not None: @@ -1566,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # Only fp8 makes sense for embeddings (per-row dequant via index select). # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. - quant_format = layer_conf.get("format", None) if layer_conf is not None else None - if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + quant_format = layer_conf.get("format") if layer_conf is not None else None + manually_loaded_keys = [] + + if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict: self.quant_format = quant_format qconfig = QUANT_ALGOS[quant_format] self.layout_type = qconfig["comfy_tensor_layout"] layout_cls = get_layout_class(self.layout_type) weight = state_dict.pop(weight_key) - manually_loaded_keys = [weight_key] + manually_loaded_keys.append(weight_key) scale_key = f"{prefix}weight_scale" scale = state_dict.pop(scale_key, None) @@ -1589,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter( QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), requires_grad=False) + elif layer_conf is not None: + # Unsupported format — restore the marker so it round-trips; fall through to default load. + state_dict[f"{prefix}comfy_quant"] = torch.tensor( + list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - for k in manually_loaded_keys: - if k in missing_keys: - missing_keys.remove(k) - else: - if layer_conf is not None: - state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight') or self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix) def forward_comfy_cast_weights(self, input, out_dtype=None): weight = self.weight diff --git a/comfy/text_encoders/gpt_oss.py b/comfy/text_encoders/gpt_oss.py index 2453d8d74..d596ef9a0 100644 --- a/comfy/text_encoders/gpt_oss.py +++ b/comfy/text_encoders/gpt_oss.py @@ -218,19 +218,21 @@ class GptOssExperts(nn.Module): expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for ei in expert_hit: - expert_idx = int(ei.item()) - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current = hidden_states[token_idx] + with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \ + self.down_proj.bank_resident(hidden_states) as down_bank: + for ei in expert_hit: + expert_idx = int(ei.item()) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current = hidden_states[token_idx] - gate_up = self.gate_up_proj.expert_linear(current, expert_idx) - gated = self._apply_gate(gate_up) - expert_out = self.down_proj.expert_linear(gated, expert_idx) + gate_up = gate_up_bank.expert_linear(current, expert_idx) + gated = self._apply_gate(gate_up) + expert_out = down_bank.expert_linear(gated, expert_idx) - weighted = expert_out * routing_weights[token_idx, top_k_pos, None] + weighted = expert_out * routing_weights[token_idx, top_k_pos, None] - flat_idx = token_idx * top_k + top_k_pos - per_pair[flat_idx] = weighted.to(per_pair.dtype) + flat_idx = token_idx * top_k + top_k_pos + per_pair[flat_idx] = weighted.to(per_pair.dtype) return per_pair.view(N, top_k, H).sum(dim=1)