diff --git a/comfy/ops.py b/comfy/ops.py index 9bcd6c900..3ee7c1216 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1333,6 +1333,204 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self._buffers[key] = fn(buf) return self + class MoEExperts(torch.nn.Module): + """Container for E quantized expert weights, indexed via ``expert_weight(i)``. + + Holds expert weights as 3D buffers/parameters. + + State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a + leading expert dim — exact storage shape is layout-specific):: + + {prefix}.weight quant data (storage_t), leading dim = E + {prefix}.weight_scale block / per-tensor scale + {prefix}.weight_scale_2 [E] or scalar NVFP4 only + {prefix}.bias [E, out_features] optional, bf16 + {prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}} + + Without ``comfy_quant`` the weight loads as a plain bf16 3D Parameter ``[E, out, in]``. + """ + + def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): + super().__init__() + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + if bias: + self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + # Populated by _load_from_state_dict: + self.weight = None # bf16 fallback: 3D Parameter [E, out, in] + self.quant_format = None + self.layout_type = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False + + def reset_parameters(self): + return None + + def _load_scale_param(self, state_dict, prefix, param_name, device, + manually_loaded_keys, dtype=None): + key = f"{prefix}{param_name}" + value = state_dict.pop(key, None) + if value is not None: + value = value.to(device=device) + if dtype is not None: + value = value.view(dtype=dtype) + manually_loaded_keys.append(key) + return value + + # TODO: refactor to share more code with Linear._load_from_state_dict + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip(".") + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + logging.warning(f"Missing weight for MoEExperts layer {layer_name}") + return + manually_loaded_keys = [weight_key] + + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + manually_loaded_keys.append(f"{prefix}comfy_quant") + + if layer_conf is None: + self.weight = torch.nn.Parameter( + weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), + requires_grad=False, + ) + else: + self.quant_format = layer_conf.get("format") + self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) + if not self._full_precision_mm: + self._full_precision_mm = self._full_precision_mm_config + + if self.quant_format in MixedPrecisionOps._disabled: + self._full_precision_mm = True + + if self.quant_format is None: + raise ValueError(f"Unknown quant format for MoEExperts layer {layer_name}") + + qconfig = QUANT_ALGOS[self.quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] + + if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): + ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) + self.register_buffer("_tensor_scale", ts, persistent=False) + elif self.quant_format == "mxfp8": + bs = self._load_scale_param(state_dict, prefix, "weight_scale", device, + manually_loaded_keys, dtype=torch.uint8) + if bs is None: + raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}") + self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False) + elif self.quant_format == "nvfp4": + ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) + bs = self._load_scale_param(state_dict, prefix, "weight_scale", device, + manually_loaded_keys, dtype=torch.float8_e4m3fn) + if ts is None or bs is None: + raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}") + self.register_buffer("_tensor_scale", ts, persistent=False) + self.register_buffer("_block_scale", bs, persistent=False) + else: + raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}") + + self.register_buffer( + "_qdata", + weight.to(device=device, dtype=qconfig["storage_t"]), + persistent=False, + ) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) + + def expert_weight(self, i: int): + """Expert i's weight (Tensor or QuantizedTensor).""" + if self.quant_format is None: + return self.weight[i] + + qdata = self._qdata[i] + layout_cls = get_layout_class(self.layout_type) + orig_shape = (self.out_features, self.in_features) + + if self.quant_format in ("float8_e4m3fn", "float8_e5m2"): + scale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale + params = layout_cls.Params( + scale=scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=orig_shape, + ) + elif self.quant_format == "mxfp8": + params = layout_cls.Params( + scale=self._block_scale[i], + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=orig_shape, + ) + elif self.quant_format == "nvfp4": + tscale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale + params = layout_cls.Params( + scale=tscale, + block_scale=self._block_scale[i], + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=orig_shape, + ) + else: + raise ValueError(f"Unsupported quant format: {self.quant_format}") + return QuantizedTensor(qdata, self.layout_type, params) + + def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor: + """Linear against expert ``i``'s weight (with optional bias).""" + qw = self.expert_weight(i) + bias = None + if self.bias is not None: + bias = cast_to_input(self.bias[i], input, copy=False) + + if isinstance(qw, QuantizedTensor): + use_fast = ( + not self._full_precision_mm + and qw.layout_cls.supports_fast_matmul() + and input.dim() == 2 + ) + if use_fast: + qin = QuantizedTensor.from_float(input, self.layout_type) + return torch.nn.functional.linear(qin, qw, bias) + out = input @ qw.dequantize().t() + return out + bias if bias is not None else out + + return torch.nn.functional.linear(input, qw, bias) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = destination if destination is not None else {} + if self.bias is not None: + sd[f"{prefix}bias"] = self.bias + if self.quant_format is None: + if self.weight is not None: + sd[f"{prefix}weight"] = self.weight + return sd + + sd[f"{prefix}weight"] = self._qdata + if self.quant_format == "nvfp4": + sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8) + sd[f"{prefix}weight_scale_2"] = self._tensor_scale + elif self.quant_format == "mxfp8": + sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8) + elif self.quant_format in ("float8_e4m3fn", "float8_e5m2"): + sd[f"{prefix}weight_scale"] = self._tensor_scale + + quant_conf = {"format": self.quant_format, "num_experts": self.num_experts} + if self._full_precision_mm_config: + quant_conf["full_precision_matrix_mult"] = True + sd[f"{prefix}comfy_quant"] = torch.tensor( + list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8 + ) + return sd + class Embedding(manual_cast.Embedding): def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): diff --git a/comfy/sd.py b/comfy/sd.py index b18f32dec..446ac381d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1366,10 +1366,7 @@ def detect_te_model(sd): return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B # Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512). - if "model.layers.0.self_attn.sinks" in sd and ( - "model.layers.0.mlp.experts.gate_up_proj" in sd - or "model.layers.0.mlp.experts.gate_up_proj_blocks" in sd - ): + if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd: return TEModel.GPT_OSS_20B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] @@ -1554,8 +1551,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) elif te_model == TEModel.GPT_OSS_20B: - mxfp4 = any("model.layers.0.mlp.experts.gate_up_proj_blocks" in sd for sd in clip_data) - clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data), mxfp4_runtime=mxfp4) + clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.QWEN3_4B: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c5fb5e5cd..e451892e9 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -859,10 +859,8 @@ class Lens(supported_models_base.BASE): pref = self.text_encoder_key_prefix[0] for hint in ("gpt_oss.transformer.", ""): full_prefix = "{}{}".format(pref, hint) - if "{}model.layers.0.self_attn.sinks".format(full_prefix) in state_dict: + if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict: detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix) - if "{}model.layers.0.mlp.experts.gate_up_proj_blocks".format(full_prefix) in state_dict: - detect["mxfp4_runtime"] = True return supported_models_base.ClipTarget( comfy.text_encoders.gpt_oss.LensTokenizer, comfy.text_encoders.gpt_oss.lens_te(**detect), diff --git a/comfy/text_encoders/gpt_oss.py b/comfy/text_encoders/gpt_oss.py index 7c390dd3e..cf532e87a 100644 --- a/comfy/text_encoders/gpt_oss.py +++ b/comfy/text_encoders/gpt_oss.py @@ -2,11 +2,9 @@ from __future__ import annotations -import gc -import logging import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, List, Optional, Sequence import torch import torch.nn as nn @@ -201,7 +199,7 @@ class GptOssTopKRouter(nn.Module): class GptOssExperts(nn.Module): - def __init__(self, config: GptOss20BConfig, device=None, dtype=None): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): super().__init__() self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size @@ -213,29 +211,8 @@ class GptOssExperts(nn.Module): H = self.hidden_size I = self.intermediate_size - self.gate_up_proj_bias = nn.Parameter(torch.empty(E, 2 * I, device=device, dtype=dtype)) - self.down_proj_bias = nn.Parameter(torch.empty(E, H, device=device, dtype=dtype)) - self.gate_up_proj = nn.Parameter(torch.empty(E, H, 2 * I, device=device, dtype=dtype)) - self.down_proj = nn.Parameter(torch.empty(E, I, H, device=device, dtype=dtype)) - - def switch_to_mxfp4(self, device=None): - """Swap bf16 weight Parameters for uint8 MXFP4 packed buffers. - - On-disk MXFP4 layout: ``[E, 2*I, G_up, 16]`` uint8 + ``[E, 2*I, G_up]`` - uint8 (E8M0) for ``gate_up``; ``[E, H, G_down, 16]`` + ``[E, H, G_down]`` - for ``down``. ``G_up * 32 = H``, ``G_down * 32 = I``. - """ - E, H, I = self.num_experts, self.hidden_size, self.intermediate_size - if H % 32 != 0 or I % 32 != 0: - raise ValueError(f"MXFP4 requires H, I divisible by 32; got H={H}, I={I}") - del self.gate_up_proj - del self.down_proj - G_up = H // 32 - G_down = I // 32 - self.register_buffer("gate_up_proj_blocks", torch.empty(E, 2 * I, G_up, 16, dtype=torch.uint8, device=device)) - self.register_buffer("gate_up_proj_scales", torch.empty(E, 2 * I, G_up, dtype=torch.uint8, device=device)) - self.register_buffer("down_proj_blocks", torch.empty(E, H, G_down, 16, dtype=torch.uint8, device=device)) - self.register_buffer("down_proj_scales", torch.empty(E, H, G_down, dtype=torch.uint8, device=device)) + self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype) + self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype) def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: gate = gate_up[..., ::2] @@ -245,24 +222,6 @@ class GptOssExperts(nn.Module): glu = gate * torch.sigmoid(gate * self.alpha) return (up + 1) * glu - @staticmethod - def _dequant_one_expert( - blocks_e: torch.Tensor, - scales_e: torch.Tensor, - dtype: torch.dtype, - ) -> torch.Tensor: - """Dequant one expert's MXFP4 ``[D, G, 16]`` + ``[D, G]`` to ``[G*32, D]``.""" - D, G, B = blocks_e.shape - val_per_row = G * 32 - lut = _fp4_lut(dtype, blocks_e.device) - blocks_flat = blocks_e.reshape(D * G, B) - scales_flat = (scales_e.to(torch.int32) - 127).reshape(D * G, 1) - dec = torch.empty(D * G, B * 2, dtype=dtype, device=blocks_e.device) - dec[:, 0::2] = lut[(blocks_flat & 0x0F).to(torch.long)] - dec[:, 1::2] = lut[(blocks_flat >> 4).to(torch.long)] - torch.ldexp(dec, scales_flat, out=dec) - return dec.view(D, val_per_row).transpose(0, 1) - def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: N = hidden_states.shape[0] top_k = router_indices.shape[-1] @@ -273,34 +232,15 @@ class GptOssExperts(nn.Module): expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - is_mxfp4 = hasattr(self, "gate_up_proj_blocks") - for ei in expert_hit: expert_idx = int(ei.item()) top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current = hidden_states[token_idx] - if is_mxfp4: - gate_up_w = self._dequant_one_expert( - self.gate_up_proj_blocks[expert_idx], - self.gate_up_proj_scales[expert_idx], - current.dtype, - ) - down_w = self._dequant_one_expert( - self.down_proj_blocks[expert_idx], - self.down_proj_scales[expert_idx], - current.dtype, - ) - else: - gate_up_w = comfy.ops.cast_to_input(self.gate_up_proj[expert_idx], current, copy=False) - down_w = comfy.ops.cast_to_input(self.down_proj[expert_idx], current, copy=False) - - gate_up_b = comfy.ops.cast_to_input(self.gate_up_proj_bias[expert_idx], current, copy=False) - down_b = comfy.ops.cast_to_input(self.down_proj_bias[expert_idx], current, copy=False) - - gate_up = current @ gate_up_w + gate_up_b + gate_up = self.gate_up_proj.expert_linear(current, expert_idx) gated = self._apply_gate(gate_up) - expert_out = gated @ down_w + down_b + expert_out = self.down_proj.expert_linear(gated, expert_idx) + weighted = expert_out * routing_weights[token_idx, top_k_pos, None] flat_idx = token_idx * top_k + top_k_pos @@ -310,10 +250,10 @@ class GptOssExperts(nn.Module): class GptOssMLP(nn.Module): - def __init__(self, config: GptOss20BConfig, device=None, dtype=None): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): super().__init__() self.router = GptOssTopKRouter(config, device=device, dtype=dtype) - self.experts = GptOssExperts(config, device=device, dtype=dtype) + self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: B, S, H = hidden_states.shape @@ -329,7 +269,7 @@ class GptOssDecoderLayer(nn.Module): def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): super().__init__() self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops) - self.mlp = GptOssMLP(config, device=device, dtype=dtype) + self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) self.layer_type = config.layer_types[layer_idx] @@ -579,125 +519,24 @@ class LensTokenizer(sd1_clip.SD1Tokenizer): ) -# MXFP4 E2M1 LUT (1 sign + 2 exp + 1 mantissa). -_FP4_VALUES: Tuple[float, ...] = ( - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, -) - - -_FP4_LUT_CACHE: Dict[Tuple[torch.dtype, str], torch.Tensor] = {} - - -def _fp4_lut(dtype: torch.dtype, device: torch.device) -> torch.Tensor: - """Cached per (dtype, device) FP4 lookup table — avoids per-call allocation.""" - key = (dtype, str(device)) - lut = _FP4_LUT_CACHE.get(key) - if lut is None: - lut = torch.tensor(_FP4_VALUES, dtype=dtype, device=device) - _FP4_LUT_CACHE[key] = lut - return lut - - -def _safe_dequant_moe_tensor( - blocks: torch.Tensor, - scales: torch.Tensor, - *, - dtype: torch.dtype = torch.bfloat16, - rows_per_chunk: int = 4096, -) -> torch.Tensor: - """Eager full-tensor MXFP4 dequant -> ``[E, H, 2*I]`` ``dtype``. - - Allocates the output in its final transposed layout and writes in chunks - """ - blocks = blocks.to(torch.uint8) - scales = scales.to(torch.int32) - 127 - - assert blocks.shape[:-1] == scales.shape, ( - f"{blocks.shape[:-1]=} does not match {scales.shape=}" - ) - *prefix_shape, G, B = blocks.shape - if len(prefix_shape) != 2: - raise ValueError(f"expected 2-D prefix (E, 2*I); got {prefix_shape}") - - E, D = prefix_shape - val_per_row = G * B * 2 # this is H after dequant - - rows_total = E * D * G - blocks = blocks.reshape(rows_total, B) - scales = scales.reshape(rows_total, 1) - - lut = _fp4_lut(dtype, blocks.device) - out = torch.empty(E, val_per_row, D, dtype=dtype, device=blocks.device) - - for e in range(E): - for d0 in range(0, D, rows_per_chunk): - d1 = min(d0 + rows_per_chunk, D) - r0 = e * D * G + d0 * G - r1 = e * D * G + d1 * G - blk = blocks[r0:r1] - exp = scales[r0:r1] - dec = torch.empty((d1 - d0) * G, B * 2, dtype=dtype, device=blocks.device) - dec[:, 0::2] = lut[(blk & 0x0F).to(torch.long)] - dec[:, 1::2] = lut[(blk >> 4).to(torch.long)] - torch.ldexp(dec, exp, out=dec) - out[e, :, d0:d1] = dec.view(d1 - d0, val_per_row).transpose(0, 1) - del blk, exp, dec - return out - - -def _dequant_mxfp4_state_dict(sd: Dict[str, torch.Tensor], target_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - """Eager-dequant every ``*_blocks``/``*_scales`` pair in ``sd`` in place.""" - pairs: List[Tuple[str, str, str]] = [] - for k in list(sd.keys()): - if k.endswith("_blocks"): - stem = k[: -len("_blocks")] - sk = stem + "_scales" - if sk in sd: - pairs.append((stem, k, sk)) - - if not pairs: - return sd - - logging.info("Lens: dequantizing %d MXFP4 expert tensors -> %s", len(pairs), target_dtype) - for stem, bk, sk in pairs: - blocks = sd.pop(bk) - scales = sd.pop(sk) - sd[stem] = _safe_dequant_moe_tensor(blocks, scales, dtype=target_dtype) - del blocks, scales - - gc.collect() - return sd - - class LensGptOssClipModel(nn.Module): """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor).""" - def __init__(self, device="cpu", dtype=None, model_options=None, **_): + def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs): super().__init__() model_options = dict(model_options or {}) operations = model_options.get("custom_operations") - quant_config = model_options.get("quantization_metadata") if operations is None: - if quant_config is not None: - operations = comfy.ops.mixed_precision_ops( - quant_config, dtype, full_precision_mm=True - ) - else: - operations = comfy.ops.manual_cast + quant_config = model_options.get("quantization_metadata") or {} + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) self.operations = operations cfg_overrides = model_options.get("gpt_oss_config", {}) self.config = GptOss20BConfig(**cfg_overrides) - self.selected_layers = tuple( - model_options.get("selected_layers", LENS_SELECTED_LAYERS) - ) + self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS)) self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET)) - # mxfp4_runtime=True keeps experts packed and dequants per hit at forward. - self.mxfp4_runtime = bool(model_options.get("mxfp4_runtime", False)) - self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations) self.num_layers = self.config.num_hidden_layers self.dtype = dtype @@ -755,17 +594,6 @@ class LensGptOssClipModel(nn.Module): return flat, None, extra def load_sd(self, sd): - if any(k.startswith("model.") for k in sd): - sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()} - sd.pop("lm_head.weight", None) - - if self.mxfp4_runtime: - device = next(self.transformer.parameters()).device - for layer in self.transformer.layers: - layer.mlp.experts.switch_to_mxfp4(device=device) - else: - sd = _dequant_mxfp4_state_dict(sd, self.dtype) - return self.transformer.load_state_dict(sd, strict=False, assign=True) @@ -774,7 +602,7 @@ class LensTEModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {}) -def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=False): +def lens_te(dtype_llama=None, llama_quantization_metadata=None): class LensTEModel_(LensTEModel): def __init__(self, device="cpu", dtype=None, model_options=None): mo = dict(model_options or {}) @@ -782,8 +610,6 @@ def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=Fa mo["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - if mxfp4_runtime: - mo["mxfp4_runtime"] = True super().__init__(device=device, dtype=dtype, model_options=mo) return LensTEModel_