mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Pivot from using mxfp4 to nvfp4
This commit is contained in:
parent
b429028ad2
commit
c431fd555b
198
comfy/ops.py
198
comfy/ops.py
@ -1333,6 +1333,204 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
class MoEExperts(torch.nn.Module):
|
||||
"""Container for E quantized expert weights, indexed via ``expert_weight(i)``.
|
||||
|
||||
Holds expert weights as 3D buffers/parameters.
|
||||
|
||||
State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
|
||||
leading expert dim — exact storage shape is layout-specific)::
|
||||
|
||||
{prefix}.weight quant data (storage_t), leading dim = E
|
||||
{prefix}.weight_scale block / per-tensor scale
|
||||
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
|
||||
{prefix}.bias [E, out_features] optional, bf16
|
||||
{prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
|
||||
|
||||
Without ``comfy_quant`` the weight loads as a plain bf16 3D Parameter ``[E, out, in]``.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
# Populated by _load_from_state_dict:
|
||||
self.weight = None # bf16 fallback: 3D Parameter [E, out, in]
|
||||
self.quant_format = None
|
||||
self.layout_type = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_scale_param(self, state_dict, prefix, param_name, device,
|
||||
manually_loaded_keys, dtype=None):
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
# TODO: refactor to share more code with Linear._load_from_state_dict
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip(".")
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for MoEExperts layer {layer_name}")
|
||||
return
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
manually_loaded_keys.append(f"{prefix}comfy_quant")
|
||||
|
||||
if layer_conf is None:
|
||||
self.weight = torch.nn.Parameter(
|
||||
weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format")
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quant format for MoEExperts layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
self.register_buffer("_tensor_scale", ts, persistent=False)
|
||||
elif self.quant_format == "mxfp8":
|
||||
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
|
||||
manually_loaded_keys, dtype=torch.uint8)
|
||||
if bs is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
|
||||
self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False)
|
||||
elif self.quant_format == "nvfp4":
|
||||
ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
|
||||
manually_loaded_keys, dtype=torch.float8_e4m3fn)
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
|
||||
self.register_buffer("_tensor_scale", ts, persistent=False)
|
||||
self.register_buffer("_block_scale", bs, persistent=False)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")
|
||||
|
||||
self.register_buffer(
|
||||
"_qdata",
|
||||
weight.to(device=device, dtype=qconfig["storage_t"]),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
|
||||
def expert_weight(self, i: int):
|
||||
"""Expert i's weight (Tensor or QuantizedTensor)."""
|
||||
if self.quant_format is None:
|
||||
return self.weight[i]
|
||||
|
||||
qdata = self._qdata[i]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
orig_shape = (self.out_features, self.in_features)
|
||||
|
||||
if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
scale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale
|
||||
params = layout_cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=orig_shape,
|
||||
)
|
||||
elif self.quant_format == "mxfp8":
|
||||
params = layout_cls.Params(
|
||||
scale=self._block_scale[i],
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=orig_shape,
|
||||
)
|
||||
elif self.quant_format == "nvfp4":
|
||||
tscale = self._tensor_scale[i] if self._tensor_scale.dim() else self._tensor_scale
|
||||
params = layout_cls.Params(
|
||||
scale=tscale,
|
||||
block_scale=self._block_scale[i],
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=orig_shape,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quant format: {self.quant_format}")
|
||||
return QuantizedTensor(qdata, self.layout_type, params)
|
||||
|
||||
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
||||
"""Linear against expert ``i``'s weight (with optional bias)."""
|
||||
qw = self.expert_weight(i)
|
||||
bias = None
|
||||
if self.bias is not None:
|
||||
bias = cast_to_input(self.bias[i], input, copy=False)
|
||||
|
||||
if isinstance(qw, QuantizedTensor):
|
||||
use_fast = (
|
||||
not self._full_precision_mm
|
||||
and qw.layout_cls.supports_fast_matmul()
|
||||
and input.dim() == 2
|
||||
)
|
||||
if use_fast:
|
||||
qin = QuantizedTensor.from_float(input, self.layout_type)
|
||||
return torch.nn.functional.linear(qin, qw, bias)
|
||||
out = input @ qw.dequantize().t()
|
||||
return out + bias if bias is not None else out
|
||||
|
||||
return torch.nn.functional.linear(input, qw, bias)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = destination if destination is not None else {}
|
||||
if self.bias is not None:
|
||||
sd[f"{prefix}bias"] = self.bias
|
||||
if self.quant_format is None:
|
||||
if self.weight is not None:
|
||||
sd[f"{prefix}weight"] = self.weight
|
||||
return sd
|
||||
|
||||
sd[f"{prefix}weight"] = self._qdata
|
||||
if self.quant_format == "nvfp4":
|
||||
sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
|
||||
sd[f"{prefix}weight_scale_2"] = self._tensor_scale
|
||||
elif self.quant_format == "mxfp8":
|
||||
sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
|
||||
elif self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
sd[f"{prefix}weight_scale"] = self._tensor_scale
|
||||
|
||||
quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(
|
||||
list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
|
||||
)
|
||||
return sd
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
@ -1366,10 +1366,7 @@ def detect_te_model(sd):
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
|
||||
if "model.layers.0.self_attn.sinks" in sd and (
|
||||
"model.layers.0.mlp.experts.gate_up_proj" in sd
|
||||
or "model.layers.0.mlp.experts.gate_up_proj_blocks" in sd
|
||||
):
|
||||
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
|
||||
return TEModel.GPT_OSS_20B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
@ -1554,8 +1551,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.GPT_OSS_20B:
|
||||
mxfp4 = any("model.layers.0.mlp.experts.gate_up_proj_blocks" in sd for sd in clip_data)
|
||||
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data), mxfp4_runtime=mxfp4)
|
||||
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
|
||||
@ -859,10 +859,8 @@ class Lens(supported_models_base.BASE):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
for hint in ("gpt_oss.transformer.", ""):
|
||||
full_prefix = "{}{}".format(pref, hint)
|
||||
if "{}model.layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
||||
if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
|
||||
if "{}model.layers.0.mlp.experts.gate_up_proj_blocks".format(full_prefix) in state_dict:
|
||||
detect["mxfp4_runtime"] = True
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(**detect),
|
||||
|
||||
@ -2,11 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -201,7 +199,7 @@ class GptOssTopKRouter(nn.Module):
|
||||
|
||||
|
||||
class GptOssExperts(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -213,29 +211,8 @@ class GptOssExperts(nn.Module):
|
||||
H = self.hidden_size
|
||||
I = self.intermediate_size
|
||||
|
||||
self.gate_up_proj_bias = nn.Parameter(torch.empty(E, 2 * I, device=device, dtype=dtype))
|
||||
self.down_proj_bias = nn.Parameter(torch.empty(E, H, device=device, dtype=dtype))
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(E, H, 2 * I, device=device, dtype=dtype))
|
||||
self.down_proj = nn.Parameter(torch.empty(E, I, H, device=device, dtype=dtype))
|
||||
|
||||
def switch_to_mxfp4(self, device=None):
|
||||
"""Swap bf16 weight Parameters for uint8 MXFP4 packed buffers.
|
||||
|
||||
On-disk MXFP4 layout: ``[E, 2*I, G_up, 16]`` uint8 + ``[E, 2*I, G_up]``
|
||||
uint8 (E8M0) for ``gate_up``; ``[E, H, G_down, 16]`` + ``[E, H, G_down]``
|
||||
for ``down``. ``G_up * 32 = H``, ``G_down * 32 = I``.
|
||||
"""
|
||||
E, H, I = self.num_experts, self.hidden_size, self.intermediate_size
|
||||
if H % 32 != 0 or I % 32 != 0:
|
||||
raise ValueError(f"MXFP4 requires H, I divisible by 32; got H={H}, I={I}")
|
||||
del self.gate_up_proj
|
||||
del self.down_proj
|
||||
G_up = H // 32
|
||||
G_down = I // 32
|
||||
self.register_buffer("gate_up_proj_blocks", torch.empty(E, 2 * I, G_up, 16, dtype=torch.uint8, device=device))
|
||||
self.register_buffer("gate_up_proj_scales", torch.empty(E, 2 * I, G_up, dtype=torch.uint8, device=device))
|
||||
self.register_buffer("down_proj_blocks", torch.empty(E, H, G_down, 16, dtype=torch.uint8, device=device))
|
||||
self.register_buffer("down_proj_scales", torch.empty(E, H, G_down, dtype=torch.uint8, device=device))
|
||||
self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
|
||||
self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
|
||||
gate = gate_up[..., ::2]
|
||||
@ -245,24 +222,6 @@ class GptOssExperts(nn.Module):
|
||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||
return (up + 1) * glu
|
||||
|
||||
@staticmethod
|
||||
def _dequant_one_expert(
|
||||
blocks_e: torch.Tensor,
|
||||
scales_e: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
"""Dequant one expert's MXFP4 ``[D, G, 16]`` + ``[D, G]`` to ``[G*32, D]``."""
|
||||
D, G, B = blocks_e.shape
|
||||
val_per_row = G * 32
|
||||
lut = _fp4_lut(dtype, blocks_e.device)
|
||||
blocks_flat = blocks_e.reshape(D * G, B)
|
||||
scales_flat = (scales_e.to(torch.int32) - 127).reshape(D * G, 1)
|
||||
dec = torch.empty(D * G, B * 2, dtype=dtype, device=blocks_e.device)
|
||||
dec[:, 0::2] = lut[(blocks_flat & 0x0F).to(torch.long)]
|
||||
dec[:, 1::2] = lut[(blocks_flat >> 4).to(torch.long)]
|
||||
torch.ldexp(dec, scales_flat, out=dec)
|
||||
return dec.view(D, val_per_row).transpose(0, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
||||
N = hidden_states.shape[0]
|
||||
top_k = router_indices.shape[-1]
|
||||
@ -273,34 +232,15 @@ class GptOssExperts(nn.Module):
|
||||
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
is_mxfp4 = hasattr(self, "gate_up_proj_blocks")
|
||||
|
||||
for ei in expert_hit:
|
||||
expert_idx = int(ei.item())
|
||||
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current = hidden_states[token_idx]
|
||||
|
||||
if is_mxfp4:
|
||||
gate_up_w = self._dequant_one_expert(
|
||||
self.gate_up_proj_blocks[expert_idx],
|
||||
self.gate_up_proj_scales[expert_idx],
|
||||
current.dtype,
|
||||
)
|
||||
down_w = self._dequant_one_expert(
|
||||
self.down_proj_blocks[expert_idx],
|
||||
self.down_proj_scales[expert_idx],
|
||||
current.dtype,
|
||||
)
|
||||
else:
|
||||
gate_up_w = comfy.ops.cast_to_input(self.gate_up_proj[expert_idx], current, copy=False)
|
||||
down_w = comfy.ops.cast_to_input(self.down_proj[expert_idx], current, copy=False)
|
||||
|
||||
gate_up_b = comfy.ops.cast_to_input(self.gate_up_proj_bias[expert_idx], current, copy=False)
|
||||
down_b = comfy.ops.cast_to_input(self.down_proj_bias[expert_idx], current, copy=False)
|
||||
|
||||
gate_up = current @ gate_up_w + gate_up_b
|
||||
gate_up = self.gate_up_proj.expert_linear(current, expert_idx)
|
||||
gated = self._apply_gate(gate_up)
|
||||
expert_out = gated @ down_w + down_b
|
||||
expert_out = self.down_proj.expert_linear(gated, expert_idx)
|
||||
|
||||
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
||||
|
||||
flat_idx = token_idx * top_k + top_k_pos
|
||||
@ -310,10 +250,10 @@ class GptOssExperts(nn.Module):
|
||||
|
||||
|
||||
class GptOssMLP(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
|
||||
self.experts = GptOssExperts(config, device=device, dtype=dtype)
|
||||
self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
B, S, H = hidden_states.shape
|
||||
@ -329,7 +269,7 @@ class GptOssDecoderLayer(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = GptOssMLP(config, device=device, dtype=dtype)
|
||||
self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
@ -579,125 +519,24 @@ class LensTokenizer(sd1_clip.SD1Tokenizer):
|
||||
)
|
||||
|
||||
|
||||
# MXFP4 E2M1 LUT (1 sign + 2 exp + 1 mantissa).
|
||||
_FP4_VALUES: Tuple[float, ...] = (
|
||||
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
||||
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
|
||||
)
|
||||
|
||||
|
||||
_FP4_LUT_CACHE: Dict[Tuple[torch.dtype, str], torch.Tensor] = {}
|
||||
|
||||
|
||||
def _fp4_lut(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
||||
"""Cached per (dtype, device) FP4 lookup table — avoids per-call allocation."""
|
||||
key = (dtype, str(device))
|
||||
lut = _FP4_LUT_CACHE.get(key)
|
||||
if lut is None:
|
||||
lut = torch.tensor(_FP4_VALUES, dtype=dtype, device=device)
|
||||
_FP4_LUT_CACHE[key] = lut
|
||||
return lut
|
||||
|
||||
|
||||
def _safe_dequant_moe_tensor(
|
||||
blocks: torch.Tensor,
|
||||
scales: torch.Tensor,
|
||||
*,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
rows_per_chunk: int = 4096,
|
||||
) -> torch.Tensor:
|
||||
"""Eager full-tensor MXFP4 dequant -> ``[E, H, 2*I]`` ``dtype``.
|
||||
|
||||
Allocates the output in its final transposed layout and writes in chunks
|
||||
"""
|
||||
blocks = blocks.to(torch.uint8)
|
||||
scales = scales.to(torch.int32) - 127
|
||||
|
||||
assert blocks.shape[:-1] == scales.shape, (
|
||||
f"{blocks.shape[:-1]=} does not match {scales.shape=}"
|
||||
)
|
||||
*prefix_shape, G, B = blocks.shape
|
||||
if len(prefix_shape) != 2:
|
||||
raise ValueError(f"expected 2-D prefix (E, 2*I); got {prefix_shape}")
|
||||
|
||||
E, D = prefix_shape
|
||||
val_per_row = G * B * 2 # this is H after dequant
|
||||
|
||||
rows_total = E * D * G
|
||||
blocks = blocks.reshape(rows_total, B)
|
||||
scales = scales.reshape(rows_total, 1)
|
||||
|
||||
lut = _fp4_lut(dtype, blocks.device)
|
||||
out = torch.empty(E, val_per_row, D, dtype=dtype, device=blocks.device)
|
||||
|
||||
for e in range(E):
|
||||
for d0 in range(0, D, rows_per_chunk):
|
||||
d1 = min(d0 + rows_per_chunk, D)
|
||||
r0 = e * D * G + d0 * G
|
||||
r1 = e * D * G + d1 * G
|
||||
blk = blocks[r0:r1]
|
||||
exp = scales[r0:r1]
|
||||
dec = torch.empty((d1 - d0) * G, B * 2, dtype=dtype, device=blocks.device)
|
||||
dec[:, 0::2] = lut[(blk & 0x0F).to(torch.long)]
|
||||
dec[:, 1::2] = lut[(blk >> 4).to(torch.long)]
|
||||
torch.ldexp(dec, exp, out=dec)
|
||||
out[e, :, d0:d1] = dec.view(d1 - d0, val_per_row).transpose(0, 1)
|
||||
del blk, exp, dec
|
||||
return out
|
||||
|
||||
|
||||
def _dequant_mxfp4_state_dict(sd: Dict[str, torch.Tensor], target_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
"""Eager-dequant every ``*_blocks``/``*_scales`` pair in ``sd`` in place."""
|
||||
pairs: List[Tuple[str, str, str]] = []
|
||||
for k in list(sd.keys()):
|
||||
if k.endswith("_blocks"):
|
||||
stem = k[: -len("_blocks")]
|
||||
sk = stem + "_scales"
|
||||
if sk in sd:
|
||||
pairs.append((stem, k, sk))
|
||||
|
||||
if not pairs:
|
||||
return sd
|
||||
|
||||
logging.info("Lens: dequantizing %d MXFP4 expert tensors -> %s", len(pairs), target_dtype)
|
||||
for stem, bk, sk in pairs:
|
||||
blocks = sd.pop(bk)
|
||||
scales = sd.pop(sk)
|
||||
sd[stem] = _safe_dequant_moe_tensor(blocks, scales, dtype=target_dtype)
|
||||
del blocks, scales
|
||||
|
||||
gc.collect()
|
||||
return sd
|
||||
|
||||
|
||||
class LensGptOssClipModel(nn.Module):
|
||||
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
|
||||
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None, **_):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
|
||||
super().__init__()
|
||||
model_options = dict(model_options or {})
|
||||
|
||||
operations = model_options.get("custom_operations")
|
||||
quant_config = model_options.get("quantization_metadata")
|
||||
if operations is None:
|
||||
if quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(
|
||||
quant_config, dtype, full_precision_mm=True
|
||||
)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
quant_config = model_options.get("quantization_metadata") or {}
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
self.operations = operations
|
||||
|
||||
cfg_overrides = model_options.get("gpt_oss_config", {})
|
||||
self.config = GptOss20BConfig(**cfg_overrides)
|
||||
self.selected_layers = tuple(
|
||||
model_options.get("selected_layers", LENS_SELECTED_LAYERS)
|
||||
)
|
||||
self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
|
||||
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
|
||||
|
||||
# mxfp4_runtime=True keeps experts packed and dequants per hit at forward.
|
||||
self.mxfp4_runtime = bool(model_options.get("mxfp4_runtime", False))
|
||||
|
||||
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
|
||||
self.num_layers = self.config.num_hidden_layers
|
||||
self.dtype = dtype
|
||||
@ -755,17 +594,6 @@ class LensGptOssClipModel(nn.Module):
|
||||
return flat, None, extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
if any(k.startswith("model.") for k in sd):
|
||||
sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()}
|
||||
sd.pop("lm_head.weight", None)
|
||||
|
||||
if self.mxfp4_runtime:
|
||||
device = next(self.transformer.parameters()).device
|
||||
for layer in self.transformer.layers:
|
||||
layer.mlp.experts.switch_to_mxfp4(device=device)
|
||||
else:
|
||||
sd = _dequant_mxfp4_state_dict(sd, self.dtype)
|
||||
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
||||
|
||||
|
||||
@ -774,7 +602,7 @@ class LensTEModel(sd1_clip.SD1ClipModel):
|
||||
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
|
||||
|
||||
|
||||
def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=False):
|
||||
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class LensTEModel_(LensTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
mo = dict(model_options or {})
|
||||
@ -782,8 +610,6 @@ def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=Fa
|
||||
mo["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if mxfp4_runtime:
|
||||
mo["mxfp4_runtime"] = True
|
||||
super().__init__(device=device, dtype=dtype, model_options=mo)
|
||||
|
||||
return LensTEModel_
|
||||
|
||||
Loading…
Reference in New Issue
Block a user