Pivot from using mxfp4 to nvfp4

This commit is contained in:
kijai 2026-05-24 02:32:02 +03:00
parent b429028ad2
commit c431fd555b
4 changed files with 216 additions and 198 deletions

View File

@ -1333,6 +1333,204 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
return self return self
class MoEExperts(torch.nn.Module):
    """Container for E expert weights, exposed one expert at a time.

    The experts are kept stacked along a leading expert dimension, either as
    a quantized buffer with its scale buffers, or as a plain 3D parameter.
    ``expert_weight(i)`` returns expert ``i``'s weight and ``expert_linear``
    applies it (plus the optional bias) to an input.

    State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
    leading expert dim; the exact storage shape is layout-specific)::

        {prefix}weight           quant data (storage_t), leading dim = E
        {prefix}weight_scale     block / per-tensor scale
        {prefix}weight_scale_2   [E] or scalar, NVFP4 only
        {prefix}bias             [E, out_features], optional, compute dtype
        {prefix}comfy_quant      json -> {"format": "...", "num_experts": E}

    Without ``comfy_quant`` the weight is loaded as a plain, non-trainable
    3D Parameter in the compute dtype (indexed per expert as ``weight[i]``).
    """

    def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        """Create the container; weights are filled in by ``_load_from_state_dict``.

        Args:
            num_experts: number of experts E (leading dim of every stacked tensor).
            in_features: input width of each expert.
            out_features: output width of each expert.
            bias: when True an uninitialised ``[E, out_features]`` bias is allocated.
            device: device used for the bias and for tensors moved at load time.
            dtype: accepted for signature compatibility; not used, the dtype
                always comes from ``MixedPrecisionOps._compute_dtype``.
        """
        super().__init__()
        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features
        self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
        else:
            self.register_parameter("bias", None)
        # Populated by _load_from_state_dict:
        self.weight = None  # unquantized fallback: 3D Parameter, one slice per expert
        self.quant_format = None
        self.layout_type = None
        self._full_precision_mm = MixedPrecisionOps._full_precision_mm
        self._full_precision_mm_config = False

    def reset_parameters(self):
        """No-op: parameters only ever come from a state dict."""
        return None

    @staticmethod
    def _scale_for_expert(scale, i: int):
        """Select the tensor scale that applies to expert ``i``.

        ``None`` (no scale was present in the checkpoint) is passed through
        unchanged instead of raising, a 0-dim tensor is shared by all experts
        and returned as is, and any other tensor is indexed along dim 0.
        """
        if scale is None or scale.dim() == 0:
            return scale
        return scale[i]

    def _load_scale_param(self, state_dict, prefix, param_name, device,
                          manually_loaded_keys, dtype=None):
        """Pop ``{prefix}{param_name}`` from ``state_dict`` and move it to ``device``.

        When ``dtype`` is given the tensor is reinterpreted with ``view`` (no
        value conversion). The key is appended to ``manually_loaded_keys``
        only when it was present. Returns the tensor, or ``None`` if absent.
        """
        key = f"{prefix}{param_name}"
        value = state_dict.pop(key, None)
        if value is not None:
            value = value.to(device=device)
            if dtype is not None:
                value = value.view(dtype=dtype)
            manually_loaded_keys.append(key)
        return value

    # TODO: refactor to share more code with Linear._load_from_state_dict
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load the stacked expert weight (quantized or plain) and its scales.

        The weight, ``comfy_quant`` and scale entries are popped from
        ``state_dict`` and handled here; everything else (e.g. the bias) is
        left to the default ``torch.nn.Module`` loader. Keys handled here are
        removed from ``missing_keys`` afterwards.

        Raises:
            ValueError: if the quant config has no format, the format is not
                supported here, or required MXFP8 / NVFP4 scales are missing.
        """
        device = self.factory_kwargs["device"]
        layer_name = prefix.rstrip(".")
        weight_key = f"{prefix}weight"
        weight = state_dict.pop(weight_key, None)
        if weight is None:
            # Nothing is loaded for this layer in that case (the default loader is skipped too).
            logging.warning(f"Missing weight for MoEExperts layer {layer_name}")
            return
        manually_loaded_keys = [weight_key]

        layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
        if layer_conf is not None:
            # The config is stored as a uint8 tensor holding UTF-8 encoded JSON.
            layer_conf = json.loads(layer_conf.numpy().tobytes())
            manually_loaded_keys.append(f"{prefix}comfy_quant")

        if layer_conf is None:
            self.weight = torch.nn.Parameter(
                weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype),
                requires_grad=False,
            )
        else:
            self.quant_format = layer_conf.get("format")
            self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
            # A global full-precision setting wins; otherwise follow the per-layer config.
            if not self._full_precision_mm:
                self._full_precision_mm = self._full_precision_mm_config
            if self.quant_format in MixedPrecisionOps._disabled:
                self._full_precision_mm = True
            if self.quant_format is None:
                raise ValueError(f"Unknown quant format for MoEExperts layer {layer_name}")
            qconfig = QUANT_ALGOS[self.quant_format]
            self.layout_type = qconfig["comfy_tensor_layout"]

            if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
                # The scale may be absent here; None is registered and tolerated downstream.
                ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
                self.register_buffer("_tensor_scale", ts, persistent=False)
            elif self.quant_format == "mxfp8":
                bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
                                            manually_loaded_keys, dtype=torch.uint8)
                if bs is None:
                    raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
                self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False)
            elif self.quant_format == "nvfp4":
                ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
                bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
                                            manually_loaded_keys, dtype=torch.float8_e4m3fn)
                if ts is None or bs is None:
                    raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
                self.register_buffer("_tensor_scale", ts, persistent=False)
                self.register_buffer("_block_scale", bs, persistent=False)
            else:
                raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")

            self.register_buffer(
                "_qdata",
                weight.to(device=device, dtype=qconfig["storage_t"]),
                persistent=False,
            )

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        # The default loader cannot see the keys popped above; do not report them as missing.
        for k in manually_loaded_keys:
            if k in missing_keys:
                missing_keys.remove(k)

    def expert_weight(self, i: int):
        """Expert i's weight (Tensor or QuantizedTensor).

        Raises:
            ValueError: if ``quant_format`` is set to a format not handled here.
        """
        if self.quant_format is None:
            return self.weight[i]
        qdata = self._qdata[i]
        layout_cls = get_layout_class(self.layout_type)
        orig_shape = (self.out_features, self.in_features)
        if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
            # A missing FP8 scale is forwarded as None rather than crashing on
            # ``None.dim()``. NOTE(review): whether the layout accepts a None
            # scale is decided by layout_cls.Params; confirm.
            params = layout_cls.Params(
                scale=self._scale_for_expert(self._tensor_scale, i),
                orig_dtype=MixedPrecisionOps._compute_dtype,
                orig_shape=orig_shape,
            )
        elif self.quant_format == "mxfp8":
            params = layout_cls.Params(
                scale=self._block_scale[i],
                orig_dtype=MixedPrecisionOps._compute_dtype,
                orig_shape=orig_shape,
            )
        elif self.quant_format == "nvfp4":
            params = layout_cls.Params(
                scale=self._scale_for_expert(self._tensor_scale, i),
                block_scale=self._block_scale[i],
                orig_dtype=MixedPrecisionOps._compute_dtype,
                orig_shape=orig_shape,
            )
        else:
            raise ValueError(f"Unsupported quant format: {self.quant_format}")
        return QuantizedTensor(qdata, self.layout_type, params)

    def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
        """Linear against expert ``i``'s weight (with optional bias).

        For a quantized weight the input is quantized as well and passed to
        ``torch.nn.functional.linear`` only when full-precision matmul is off,
        the layout reports fast-matmul support and ``input`` is 2D; otherwise
        the weight is dequantized and a regular matmul is used.
        """
        qw = self.expert_weight(i)
        bias = None
        if self.bias is not None:
            bias = cast_to_input(self.bias[i], input, copy=False)
        if isinstance(qw, QuantizedTensor):
            use_fast = (
                not self._full_precision_mm
                and qw.layout_cls.supports_fast_matmul()
                and input.dim() == 2
            )
            if use_fast:
                qin = QuantizedTensor.from_float(input, self.layout_type)
                return torch.nn.functional.linear(qin, qw, bias)
            out = input @ qw.dequantize().t()
            return out + bias if bias is not None else out
        return torch.nn.functional.linear(input, qw, bias)

    def state_dict(self, *args, destination=None, prefix="", **kwargs):
        """Return the entries of this layer in the on-disk layout read by the loader.

        Positional arguments and extra keyword arguments are ignored. Entries
        are written into ``destination`` when given, else into a new dict.
        A tensor scale that was never loaded (``None``) is left out rather
        than being emitted as a ``None`` entry.
        """
        sd = destination if destination is not None else {}
        if self.bias is not None:
            sd[f"{prefix}bias"] = self.bias
        if self.quant_format is None:
            if self.weight is not None:
                sd[f"{prefix}weight"] = self.weight
            return sd

        sd[f"{prefix}weight"] = self._qdata
        if self.quant_format == "nvfp4":
            # Block scales are exported as raw bytes; the loader views them back.
            sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
            sd[f"{prefix}weight_scale_2"] = self._tensor_scale
        elif self.quant_format == "mxfp8":
            sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
        elif self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
            if self._tensor_scale is not None:
                sd[f"{prefix}weight_scale"] = self._tensor_scale

        quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
        if self._full_precision_mm_config:
            quant_conf["full_precision_matrix_mult"] = True
        sd[f"{prefix}comfy_quant"] = torch.tensor(
            list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
        )
        return sd
class Embedding(manual_cast.Embedding): class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs): strict, missing_keys, unexpected_keys, error_msgs):

View File

@ -1366,10 +1366,7 @@ def detect_te_model(sd):
return TEModel.GEMMA_3_4B return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B return TEModel.GEMMA_2_2B
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512). # Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
if "model.layers.0.self_attn.sinks" in sd and ( if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
"model.layers.0.mlp.experts.gate_up_proj" in sd
or "model.layers.0.mlp.experts.gate_up_proj_blocks" in sd
):
return TEModel.GPT_OSS_20B return TEModel.GPT_OSS_20B
if 'model.layers.0.self_attn.k_proj.bias' in sd: if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias'] weight = sd['model.layers.0.self_attn.k_proj.bias']
@ -1554,8 +1551,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.GPT_OSS_20B: elif te_model == TEModel.GPT_OSS_20B:
mxfp4 = any("model.layers.0.mlp.experts.gate_up_proj_blocks" in sd for sd in clip_data) clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data), mxfp4_runtime=mxfp4)
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.QWEN3_4B: elif te_model == TEModel.QWEN3_4B:

View File

@ -859,10 +859,8 @@ class Lens(supported_models_base.BASE):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
for hint in ("gpt_oss.transformer.", ""): for hint in ("gpt_oss.transformer.", ""):
full_prefix = "{}{}".format(pref, hint) full_prefix = "{}{}".format(pref, hint)
if "{}model.layers.0.self_attn.sinks".format(full_prefix) in state_dict: if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix) detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
if "{}model.layers.0.mlp.experts.gate_up_proj_blocks".format(full_prefix) in state_dict:
detect["mxfp4_runtime"] = True
return supported_models_base.ClipTarget( return supported_models_base.ClipTarget(
comfy.text_encoders.gpt_oss.LensTokenizer, comfy.text_encoders.gpt_oss.LensTokenizer,
comfy.text_encoders.gpt_oss.lens_te(**detect), comfy.text_encoders.gpt_oss.lens_te(**detect),

View File

@ -2,11 +2,9 @@
from __future__ import annotations from __future__ import annotations
import gc
import logging
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple from typing import Any, List, Optional, Sequence
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -201,7 +199,7 @@ class GptOssTopKRouter(nn.Module):
class GptOssExperts(nn.Module): class GptOssExperts(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None): def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__() super().__init__()
self.num_experts = config.num_local_experts self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
@ -213,29 +211,8 @@ class GptOssExperts(nn.Module):
H = self.hidden_size H = self.hidden_size
I = self.intermediate_size I = self.intermediate_size
self.gate_up_proj_bias = nn.Parameter(torch.empty(E, 2 * I, device=device, dtype=dtype)) self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
self.down_proj_bias = nn.Parameter(torch.empty(E, H, device=device, dtype=dtype)) self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
self.gate_up_proj = nn.Parameter(torch.empty(E, H, 2 * I, device=device, dtype=dtype))
self.down_proj = nn.Parameter(torch.empty(E, I, H, device=device, dtype=dtype))
def switch_to_mxfp4(self, device=None):
    """Replace the dense expert weights with empty packed MXFP4 buffers.

    ``gate_up_proj`` and ``down_proj`` are deleted and four uint8 buffers are
    registered in their place (contents uninitialised):

    * ``gate_up_proj_blocks`` ``[E, 2*I, H/32, 16]`` and
      ``gate_up_proj_scales`` ``[E, 2*I, H/32]``
    * ``down_proj_blocks`` ``[E, H, I/32, 16]`` and
      ``down_proj_scales`` ``[E, H, I/32]``

    Raises:
        ValueError: if the hidden or intermediate size is not a multiple of 32.
    """
    n_experts = self.num_experts
    hidden = self.hidden_size
    inter = self.intermediate_size
    if not (hidden % 32 == 0 and inter % 32 == 0):
        raise ValueError(f"MXFP4 requires H, I divisible by 32; got H={hidden}, I={inter}")

    del self.gate_up_proj
    del self.down_proj

    # (buffer name stem, rows per expert, number of 32-value groups per row)
    packed_layouts = (
        ("gate_up_proj", 2 * inter, hidden // 32),
        ("down_proj", hidden, inter // 32),
    )
    for stem, rows, groups in packed_layouts:
        self.register_buffer(
            f"{stem}_blocks",
            torch.empty(n_experts, rows, groups, 16, dtype=torch.uint8, device=device),
        )
        self.register_buffer(
            f"{stem}_scales",
            torch.empty(n_experts, rows, groups, dtype=torch.uint8, device=device),
        )
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
gate = gate_up[..., ::2] gate = gate_up[..., ::2]
@ -245,24 +222,6 @@ class GptOssExperts(nn.Module):
glu = gate * torch.sigmoid(gate * self.alpha) glu = gate * torch.sigmoid(gate * self.alpha)
return (up + 1) * glu return (up + 1) * glu
@staticmethod
def _dequant_one_expert(
    blocks_e: torch.Tensor,
    scales_e: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Decode a single expert's packed MXFP4 weight.

    ``blocks_e`` is ``[D, G, 16]`` (two 4-bit codes per byte) and ``scales_e``
    is ``[D, G]`` (biased exponents, bias 127). The result is a ``[G*32, D]``
    transposed view in ``dtype``.
    """
    n_rows, n_groups, n_bytes = blocks_e.shape
    table = _fp4_lut(dtype, blocks_e.device)

    packed = blocks_e.reshape(n_rows * n_groups, n_bytes)
    exponents = (scales_e.to(torch.int32) - 127).reshape(n_rows * n_groups, 1)

    decoded = torch.empty(n_rows * n_groups, n_bytes * 2, dtype=dtype, device=blocks_e.device)
    # Low nibble goes to the even slot, high nibble to the odd slot.
    low_codes = (packed & 0x0F).to(torch.long)
    high_codes = (packed >> 4).to(torch.long)
    decoded[:, 0::2] = table[low_codes]
    decoded[:, 1::2] = table[high_codes]
    # Scale by 2**exponent in place so the result stays in ``dtype``.
    torch.ldexp(decoded, exponents, out=decoded)

    values_per_row = n_groups * 32
    return decoded.view(n_rows, values_per_row).transpose(0, 1)
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
N = hidden_states.shape[0] N = hidden_states.shape[0]
top_k = router_indices.shape[-1] top_k = router_indices.shape[-1]
@ -273,34 +232,15 @@ class GptOssExperts(nn.Module):
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
is_mxfp4 = hasattr(self, "gate_up_proj_blocks")
for ei in expert_hit: for ei in expert_hit:
expert_idx = int(ei.item()) expert_idx = int(ei.item())
top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current = hidden_states[token_idx] current = hidden_states[token_idx]
if is_mxfp4: gate_up = self.gate_up_proj.expert_linear(current, expert_idx)
gate_up_w = self._dequant_one_expert(
self.gate_up_proj_blocks[expert_idx],
self.gate_up_proj_scales[expert_idx],
current.dtype,
)
down_w = self._dequant_one_expert(
self.down_proj_blocks[expert_idx],
self.down_proj_scales[expert_idx],
current.dtype,
)
else:
gate_up_w = comfy.ops.cast_to_input(self.gate_up_proj[expert_idx], current, copy=False)
down_w = comfy.ops.cast_to_input(self.down_proj[expert_idx], current, copy=False)
gate_up_b = comfy.ops.cast_to_input(self.gate_up_proj_bias[expert_idx], current, copy=False)
down_b = comfy.ops.cast_to_input(self.down_proj_bias[expert_idx], current, copy=False)
gate_up = current @ gate_up_w + gate_up_b
gated = self._apply_gate(gate_up) gated = self._apply_gate(gate_up)
expert_out = gated @ down_w + down_b expert_out = self.down_proj.expert_linear(gated, expert_idx)
weighted = expert_out * routing_weights[token_idx, top_k_pos, None] weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
flat_idx = token_idx * top_k + top_k_pos flat_idx = token_idx * top_k + top_k_pos
@ -310,10 +250,10 @@ class GptOssExperts(nn.Module):
class GptOssMLP(nn.Module): class GptOssMLP(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None): def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__() super().__init__()
self.router = GptOssTopKRouter(config, device=device, dtype=dtype) self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
self.experts = GptOssExperts(config, device=device, dtype=dtype) self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
B, S, H = hidden_states.shape B, S, H = hidden_states.shape
@ -329,7 +269,7 @@ class GptOssDecoderLayer(nn.Module):
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
super().__init__() super().__init__()
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops) self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
self.mlp = GptOssMLP(config, device=device, dtype=dtype) self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.layer_type = config.layer_types[layer_idx] self.layer_type = config.layer_types[layer_idx]
@ -579,125 +519,24 @@ class LensTokenizer(sd1_clip.SD1Tokenizer):
) )
# MXFP4 E2M1 LUT (1 sign + 2 exp + 1 mantissa).
_FP4_VALUES: Tuple[float, ...] = (
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
)
_FP4_LUT_CACHE: Dict[Tuple[torch.dtype, str], torch.Tensor] = {}
def _fp4_lut(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
"""Cached per (dtype, device) FP4 lookup table — avoids per-call allocation."""
key = (dtype, str(device))
lut = _FP4_LUT_CACHE.get(key)
if lut is None:
lut = torch.tensor(_FP4_VALUES, dtype=dtype, device=device)
_FP4_LUT_CACHE[key] = lut
return lut
def _safe_dequant_moe_tensor(
    blocks: torch.Tensor,
    scales: torch.Tensor,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 4096,
) -> torch.Tensor:
    """Dequantize a whole packed MXFP4 expert tensor eagerly.

    ``blocks`` is ``[E, D, G, B]`` packed bytes (two 4-bit codes each) and
    ``scales`` is ``[E, D, G]`` biased exponents (bias 127). The result is an
    ``[E, G*B*2, D]`` tensor in ``dtype``. The output is allocated once in its
    final transposed layout and filled ``rows_per_chunk`` rows at a time, so
    only one chunk of decoded values exists besides the output.

    Raises:
        ValueError: if ``blocks`` does not have exactly two leading dims.
    """
    blocks = blocks.to(torch.uint8)
    scales = scales.to(torch.int32) - 127

    assert blocks.shape[:-1] == scales.shape, (
        f"{blocks.shape[:-1]=} does not match {scales.shape=}"
    )

    *leading, n_groups, n_bytes = blocks.shape
    if len(leading) != 2:
        raise ValueError(f"expected 2-D prefix (E, 2*I); got {leading}")
    n_experts, n_rows = leading
    values_per_row = n_groups * n_bytes * 2  # this is H after dequant

    flat_rows = n_experts * n_rows * n_groups
    blocks = blocks.reshape(flat_rows, n_bytes)
    scales = scales.reshape(flat_rows, 1)

    table = _fp4_lut(dtype, blocks.device)
    result = torch.empty(n_experts, values_per_row, n_rows, dtype=dtype, device=blocks.device)

    for expert in range(n_experts):
        expert_base = expert * n_rows * n_groups
        for row_start in range(0, n_rows, rows_per_chunk):
            row_stop = min(row_start + rows_per_chunk, n_rows)
            chunk = slice(expert_base + row_start * n_groups, expert_base + row_stop * n_groups)

            packed = blocks[chunk]
            exponents = scales[chunk]

            decoded = torch.empty(
                (row_stop - row_start) * n_groups, n_bytes * 2, dtype=dtype, device=blocks.device
            )
            # Low nibble -> even slot, high nibble -> odd slot.
            decoded[:, 0::2] = table[(packed & 0x0F).to(torch.long)]
            decoded[:, 1::2] = table[(packed >> 4).to(torch.long)]
            # In-place scaling keeps the chunk in ``dtype``.
            torch.ldexp(decoded, exponents, out=decoded)

            result[expert, :, row_start:row_stop] = decoded.view(
                row_stop - row_start, values_per_row
            ).transpose(0, 1)

            # Drop chunk temporaries before the next allocation.
            del packed, exponents, decoded

    return result
def _dequant_mxfp4_state_dict(sd: Dict[str, torch.Tensor], target_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
"""Eager-dequant every ``*_blocks``/``*_scales`` pair in ``sd`` in place."""
pairs: List[Tuple[str, str, str]] = []
for k in list(sd.keys()):
if k.endswith("_blocks"):
stem = k[: -len("_blocks")]
sk = stem + "_scales"
if sk in sd:
pairs.append((stem, k, sk))
if not pairs:
return sd
logging.info("Lens: dequantizing %d MXFP4 expert tensors -> %s", len(pairs), target_dtype)
for stem, bk, sk in pairs:
blocks = sd.pop(bk)
scales = sd.pop(sk)
sd[stem] = _safe_dequant_moe_tensor(blocks, scales, dtype=target_dtype)
del blocks, scales
gc.collect()
return sd
class LensGptOssClipModel(nn.Module): class LensGptOssClipModel(nn.Module):
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor).""" """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
def __init__(self, device="cpu", dtype=None, model_options=None, **_): def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
super().__init__() super().__init__()
model_options = dict(model_options or {}) model_options = dict(model_options or {})
operations = model_options.get("custom_operations") operations = model_options.get("custom_operations")
quant_config = model_options.get("quantization_metadata")
if operations is None: if operations is None:
if quant_config is not None: quant_config = model_options.get("quantization_metadata") or {}
operations = comfy.ops.mixed_precision_ops( operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
quant_config, dtype, full_precision_mm=True
)
else:
operations = comfy.ops.manual_cast
self.operations = operations self.operations = operations
cfg_overrides = model_options.get("gpt_oss_config", {}) cfg_overrides = model_options.get("gpt_oss_config", {})
self.config = GptOss20BConfig(**cfg_overrides) self.config = GptOss20BConfig(**cfg_overrides)
self.selected_layers = tuple( self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
model_options.get("selected_layers", LENS_SELECTED_LAYERS)
)
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET)) self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
# mxfp4_runtime=True keeps experts packed and dequants per hit at forward.
self.mxfp4_runtime = bool(model_options.get("mxfp4_runtime", False))
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations) self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers self.num_layers = self.config.num_hidden_layers
self.dtype = dtype self.dtype = dtype
@ -755,17 +594,6 @@ class LensGptOssClipModel(nn.Module):
return flat, None, extra return flat, None, extra
def load_sd(self, sd): def load_sd(self, sd):
if any(k.startswith("model.") for k in sd):
sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()}
sd.pop("lm_head.weight", None)
if self.mxfp4_runtime:
device = next(self.transformer.parameters()).device
for layer in self.transformer.layers:
layer.mlp.experts.switch_to_mxfp4(device=device)
else:
sd = _dequant_mxfp4_state_dict(sd, self.dtype)
return self.transformer.load_state_dict(sd, strict=False, assign=True) return self.transformer.load_state_dict(sd, strict=False, assign=True)
@ -774,7 +602,7 @@ class LensTEModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {}) super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=False): def lens_te(dtype_llama=None, llama_quantization_metadata=None):
class LensTEModel_(LensTEModel): class LensTEModel_(LensTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None): def __init__(self, device="cpu", dtype=None, model_options=None):
mo = dict(model_options or {}) mo = dict(model_options or {})
@ -782,8 +610,6 @@ def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=Fa
mo["quantization_metadata"] = llama_quantization_metadata mo["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None: if dtype_llama is not None:
dtype = dtype_llama dtype = dtype_llama
if mxfp4_runtime:
mo["mxfp4_runtime"] = True
super().__init__(device=device, dtype=dtype, model_options=mo) super().__init__(device=device, dtype=dtype, model_options=mo)
return LensTEModel_ return LensTEModel_