mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
Pivot from using mxfp4 to nvfp4
This commit is contained in:
parent
b429028ad2
commit
c431fd555b
198
comfy/ops.py
198
comfy/ops.py
@ -1333,6 +1333,204 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
self._buffers[key] = fn(buf)
|
self._buffers[key] = fn(buf)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
class MoEExperts(torch.nn.Module):
    """Container for E quantized expert weights, indexed via ``expert_weight(i)``.

    Holds expert weights as 3D buffers/parameters.

    State-dict layout (analogous to ``mixed_precision_ops.Linear`` with a
    leading expert dim — exact storage shape is layout-specific)::

        {prefix}.weight          quant data (storage_t), leading dim = E
        {prefix}.weight_scale    block / per-tensor scale
        {prefix}.weight_scale_2  [E] or scalar      NVFP4 only
        {prefix}.bias            [E, out_features]  optional, bf16
        {prefix}.comfy_quant     json -> {"format": "...", "num_experts": E}

    Without ``comfy_quant`` the weight loads as a plain bf16 3D Parameter ``[E, out, in]``.
    """

    def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        """Create an empty container; weights arrive via ``_load_from_state_dict``.

        Args:
            num_experts: number of experts E (leading dim of every tensor).
            in_features: input width of each expert's linear.
            out_features: output width of each expert's linear.
            bias: allocate an ``[E, out_features]`` bias Parameter when True.
            device: device for the bias and for tensors loaded later.
            dtype: accepted for signature parity; the bias is always allocated
                in ``MixedPrecisionOps._compute_dtype``.
        """
        super().__init__()
        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features
        self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
        else:
            self.register_parameter("bias", None)

        # Populated by _load_from_state_dict:
        self.weight = None  # bf16 fallback: 3D Parameter [E, out, in]
        self.quant_format = None
        self.layout_type = None
        self._full_precision_mm = MixedPrecisionOps._full_precision_mm
        self._full_precision_mm_config = False

    def reset_parameters(self):
        """No-op: parameters are only ever filled from a state dict."""
        return None

    def _load_scale_param(self, state_dict, prefix, param_name, device,
                          manually_loaded_keys, dtype=None):
        """Pop ``{prefix}{param_name}`` from ``state_dict`` and move it to ``device``.

        When ``dtype`` is given the tensor is reinterpreted (``view``), not
        converted. A found key is appended to ``manually_loaded_keys``.
        Returns the tensor, or None when the key is absent.
        """
        key = f"{prefix}{param_name}"
        value = state_dict.pop(key, None)
        if value is not None:
            value = value.to(device=device)
            if dtype is not None:
                value = value.view(dtype=dtype)
            manually_loaded_keys.append(key)
        return value

    def _expert_tensor_scale(self, i: int):
        """Return expert ``i``'s per-tensor scale.

        A 0-dim scale is shared by all experts and returned as-is; a scale
        with a leading dim is indexed by ``i``. Returns None when no scale was
        loaded (the FP8 loader allows ``weight_scale`` to be absent).
        """
        scale = self._tensor_scale
        if scale is None or scale.dim() == 0:
            return scale
        return scale[i]

    # TODO: refactor to share more code with Linear._load_from_state_dict
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        """Load the expert weights (quantized or plain) from ``state_dict``.

        The weight, ``comfy_quant`` config and scale tensors are popped and
        handled here; everything else (e.g. the bias) is delegated to the
        default ``torch.nn.Module`` loader. Keys consumed here are removed
        from ``missing_keys`` afterwards.

        Raises:
            ValueError: the quant config has no format, the format is not
                supported by this container, or required scales are missing.
        """
        device = self.factory_kwargs["device"]
        layer_name = prefix.rstrip(".")
        weight_key = f"{prefix}weight"
        weight = state_dict.pop(weight_key, None)
        if weight is None:
            logging.warning(f"Missing weight for MoEExperts layer {layer_name}")
            return
        manually_loaded_keys = [weight_key]

        layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
        if layer_conf is not None:
            # The config is stored as a uint8 tensor holding UTF-8 JSON bytes.
            layer_conf = json.loads(layer_conf.numpy().tobytes())
            manually_loaded_keys.append(f"{prefix}comfy_quant")

        if layer_conf is None:
            # No quant config: keep the weight as a plain frozen Parameter.
            self.weight = torch.nn.Parameter(
                weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype),
                requires_grad=False,
            )
        else:
            self.quant_format = layer_conf.get("format")
            self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
            # A global full-precision setting wins; otherwise use the per-layer one.
            if not self._full_precision_mm:
                self._full_precision_mm = self._full_precision_mm_config

            if self.quant_format in MixedPrecisionOps._disabled:
                self._full_precision_mm = True

            if self.quant_format is None:
                raise ValueError(f"Unknown quant format for MoEExperts layer {layer_name}")

            qconfig = QUANT_ALGOS[self.quant_format]
            self.layout_type = qconfig["comfy_tensor_layout"]

            if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
                # May be None when the checkpoint carries no weight_scale;
                # readers of _tensor_scale must tolerate that.
                ts = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
                self.register_buffer("_tensor_scale", ts, persistent=False)
            elif self.quant_format == "mxfp8":
                bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
                                            manually_loaded_keys, dtype=torch.uint8)
                if bs is None:
                    raise ValueError(f"Missing MXFP8 block scales for MoEExperts layer {layer_name}")
                self.register_buffer("_block_scale", bs.view(torch.float8_e8m0fnu), persistent=False)
            elif self.quant_format == "nvfp4":
                ts = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
                bs = self._load_scale_param(state_dict, prefix, "weight_scale", device,
                                            manually_loaded_keys, dtype=torch.float8_e4m3fn)
                if ts is None or bs is None:
                    raise ValueError(f"Missing NVFP4 scales for MoEExperts layer {layer_name}")
                self.register_buffer("_tensor_scale", ts, persistent=False)
                self.register_buffer("_block_scale", bs, persistent=False)
            else:
                raise ValueError(f"Unsupported MoEExperts quant format: {self.quant_format}")

            self.register_buffer(
                "_qdata",
                weight.to(device=device, dtype=qconfig["storage_t"]),
                persistent=False,
            )

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        # Keys popped above look "missing" to the default loader; un-report them.
        for k in manually_loaded_keys:
            if k in missing_keys:
                missing_keys.remove(k)

    def expert_weight(self, i: int):
        """Expert i's weight (Tensor or QuantizedTensor).

        Raises:
            ValueError: ``quant_format`` is set to a format this container
                cannot wrap.
        """
        if self.quant_format is None:
            return self.weight[i]

        qdata = self._qdata[i]
        layout_cls = get_layout_class(self.layout_type)
        orig_shape = (self.out_features, self.in_features)

        if self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
            params = layout_cls.Params(
                scale=self._expert_tensor_scale(i),
                orig_dtype=MixedPrecisionOps._compute_dtype,
                orig_shape=orig_shape,
            )
        elif self.quant_format == "mxfp8":
            params = layout_cls.Params(
                scale=self._block_scale[i],
                orig_dtype=MixedPrecisionOps._compute_dtype,
                orig_shape=orig_shape,
            )
        elif self.quant_format == "nvfp4":
            params = layout_cls.Params(
                scale=self._expert_tensor_scale(i),
                block_scale=self._block_scale[i],
                orig_dtype=MixedPrecisionOps._compute_dtype,
                orig_shape=orig_shape,
            )
        else:
            raise ValueError(f"Unsupported quant format: {self.quant_format}")
        return QuantizedTensor(qdata, self.layout_type, params)

    def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
        """Linear against expert ``i``'s weight (with optional bias).

        Quantized weights use the layout's fast matmul when it is enabled,
        supported and ``input`` is 2D; otherwise the weight is dequantized.
        Dense weights (dequantized or bf16 fallback) and the bias are cast to
        ``input``'s dtype/device before use.
        """
        qw = self.expert_weight(i)
        bias = None
        if self.bias is not None:
            bias = cast_to_input(self.bias[i], input, copy=False)

        if isinstance(qw, QuantizedTensor):
            use_fast = (
                not self._full_precision_mm
                and qw.layout_cls.supports_fast_matmul()
                and input.dim() == 2
            )
            if use_fast:
                qin = QuantizedTensor.from_float(input, self.layout_type)
                return torch.nn.functional.linear(qin, qw, bias)
            dense = cast_to_input(qw.dequantize(), input, copy=False)
            out = input @ dense.t()
            return out + bias if bias is not None else out

        # bf16 fallback: cast the weight slice the same way the bias is cast.
        return torch.nn.functional.linear(input, cast_to_input(qw, input, copy=False), bias)

    def state_dict(self, *args, destination=None, prefix="", **kwargs):
        """Serialize using the same key layout ``_load_from_state_dict`` reads.

        Writes into ``destination`` when given, else into a new dict, and
        returns it. Block scales are emitted as uint8 views; an absent FP8
        per-tensor scale is omitted rather than written as None.
        """
        sd = destination if destination is not None else {}
        if self.bias is not None:
            sd[f"{prefix}bias"] = self.bias
        if self.quant_format is None:
            if self.weight is not None:
                sd[f"{prefix}weight"] = self.weight
            return sd

        sd[f"{prefix}weight"] = self._qdata
        if self.quant_format == "nvfp4":
            sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
            sd[f"{prefix}weight_scale_2"] = self._tensor_scale
        elif self.quant_format == "mxfp8":
            sd[f"{prefix}weight_scale"] = self._block_scale.view(torch.uint8)
        elif self.quant_format in ("float8_e4m3fn", "float8_e5m2"):
            if self._tensor_scale is not None:
                sd[f"{prefix}weight_scale"] = self._tensor_scale

        quant_conf = {"format": self.quant_format, "num_experts": self.num_experts}
        if self._full_precision_mm_config:
            quant_conf["full_precision_matrix_mult"] = True
        sd[f"{prefix}comfy_quant"] = torch.tensor(
            list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8
        )
        return sd
|
||||||
|
|
||||||
class Embedding(manual_cast.Embedding):
|
class Embedding(manual_cast.Embedding):
|
||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|||||||
@ -1366,10 +1366,7 @@ def detect_te_model(sd):
|
|||||||
return TEModel.GEMMA_3_4B
|
return TEModel.GEMMA_3_4B
|
||||||
return TEModel.GEMMA_2_2B
|
return TEModel.GEMMA_2_2B
|
||||||
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
|
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
|
||||||
if "model.layers.0.self_attn.sinks" in sd and (
|
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
|
||||||
"model.layers.0.mlp.experts.gate_up_proj" in sd
|
|
||||||
or "model.layers.0.mlp.experts.gate_up_proj_blocks" in sd
|
|
||||||
):
|
|
||||||
return TEModel.GPT_OSS_20B
|
return TEModel.GPT_OSS_20B
|
||||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||||
@ -1554,8 +1551,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||||
elif te_model == TEModel.GPT_OSS_20B:
|
elif te_model == TEModel.GPT_OSS_20B:
|
||||||
mxfp4 = any("model.layers.0.mlp.experts.gate_up_proj_blocks" in sd for sd in clip_data)
|
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
|
||||||
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data), mxfp4_runtime=mxfp4)
|
|
||||||
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
|
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
|
||||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||||
elif te_model == TEModel.QWEN3_4B:
|
elif te_model == TEModel.QWEN3_4B:
|
||||||
|
|||||||
@ -859,10 +859,8 @@ class Lens(supported_models_base.BASE):
|
|||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
for hint in ("gpt_oss.transformer.", ""):
|
for hint in ("gpt_oss.transformer.", ""):
|
||||||
full_prefix = "{}{}".format(pref, hint)
|
full_prefix = "{}{}".format(pref, hint)
|
||||||
if "{}model.layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
||||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
|
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
|
||||||
if "{}model.layers.0.mlp.experts.gate_up_proj_blocks".format(full_prefix) in state_dict:
|
|
||||||
detect["mxfp4_runtime"] = True
|
|
||||||
return supported_models_base.ClipTarget(
|
return supported_models_base.ClipTarget(
|
||||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||||
comfy.text_encoders.gpt_oss.lens_te(**detect),
|
comfy.text_encoders.gpt_oss.lens_te(**detect),
|
||||||
|
|||||||
@ -2,11 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import gc
|
|
||||||
import logging
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
from typing import Any, List, Optional, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -201,7 +199,7 @@ class GptOssTopKRouter(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GptOssExperts(nn.Module):
|
class GptOssExperts(nn.Module):
|
||||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_experts = config.num_local_experts
|
self.num_experts = config.num_local_experts
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@ -213,29 +211,8 @@ class GptOssExperts(nn.Module):
|
|||||||
H = self.hidden_size
|
H = self.hidden_size
|
||||||
I = self.intermediate_size
|
I = self.intermediate_size
|
||||||
|
|
||||||
self.gate_up_proj_bias = nn.Parameter(torch.empty(E, 2 * I, device=device, dtype=dtype))
|
self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
|
||||||
self.down_proj_bias = nn.Parameter(torch.empty(E, H, device=device, dtype=dtype))
|
self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
|
||||||
self.gate_up_proj = nn.Parameter(torch.empty(E, H, 2 * I, device=device, dtype=dtype))
|
|
||||||
self.down_proj = nn.Parameter(torch.empty(E, I, H, device=device, dtype=dtype))
|
|
||||||
|
|
||||||
def switch_to_mxfp4(self, device=None):
    """Drop the dense expert weights and register empty packed-MXFP4 buffers.

    Removes ``gate_up_proj`` and ``down_proj`` and registers four uint8
    buffers in their place (``E`` experts, ``H`` hidden size, ``I``
    intermediate size, one group per 32 values)::

        gate_up_proj_blocks  [E, 2*I, H // 32, 16]
        gate_up_proj_scales  [E, 2*I, H // 32]
        down_proj_blocks     [E, H,   I // 32, 16]
        down_proj_scales     [E, H,   I // 32]

    The buffers are uninitialized (``torch.empty``) on ``device``.

    Raises:
        ValueError: ``H`` or ``I`` is not a multiple of 32.
    """
    experts = self.num_experts
    hidden = self.hidden_size
    inter = self.intermediate_size
    if not (hidden % 32 == 0 and inter % 32 == 0):
        raise ValueError(f"MXFP4 requires H, I divisible by 32; got H={hidden}, I={inter}")

    for dense_name in ("gate_up_proj", "down_proj"):
        delattr(self, dense_name)

    groups_up = hidden // 32
    groups_down = inter // 32
    packed_buffers = (
        ("gate_up_proj_blocks", (experts, 2 * inter, groups_up, 16)),
        ("gate_up_proj_scales", (experts, 2 * inter, groups_up)),
        ("down_proj_blocks", (experts, hidden, groups_down, 16)),
        ("down_proj_scales", (experts, hidden, groups_down)),
    )
    for buffer_name, shape in packed_buffers:
        self.register_buffer(buffer_name, torch.empty(*shape, dtype=torch.uint8, device=device))
|
|
||||||
|
|
||||||
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
|
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
|
||||||
gate = gate_up[..., ::2]
|
gate = gate_up[..., ::2]
|
||||||
@ -245,24 +222,6 @@ class GptOssExperts(nn.Module):
|
|||||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||||
return (up + 1) * glu
|
return (up + 1) * glu
|
||||||
|
|
||||||
@staticmethod
def _dequant_one_expert(
    blocks_e: torch.Tensor,
    scales_e: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Decode one expert's packed MXFP4 weight into a dense ``[G*32, D]`` tensor.

    ``blocks_e`` is ``[D, G, B]`` packed bytes (two 4-bit codes per byte, low
    nibble first) and ``scales_e`` is ``[D, G]`` biased exponents (bias 127).
    Each code is looked up in the FP4 table and scaled by ``2**exponent`` of
    its group; the result is returned as a transposed view.
    """
    rows, groups, bytes_per_block = blocks_e.shape
    flat_rows = rows * groups
    table = _fp4_lut(dtype, blocks_e.device)

    packed = blocks_e.reshape(flat_rows, bytes_per_block)
    exponents = (scales_e.to(torch.int32) - 127).reshape(flat_rows, 1)

    decoded = torch.empty(flat_rows, bytes_per_block * 2, dtype=dtype, device=blocks_e.device)
    # Low nibble fills even output columns, high nibble the odd ones.
    decoded[:, 0::2] = table[(packed & 0x0F).to(torch.long)]
    decoded[:, 1::2] = table[(packed >> 4).to(torch.long)]
    torch.ldexp(decoded, exponents, out=decoded)
    return decoded.view(rows, groups * 32).transpose(0, 1)
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
||||||
N = hidden_states.shape[0]
|
N = hidden_states.shape[0]
|
||||||
top_k = router_indices.shape[-1]
|
top_k = router_indices.shape[-1]
|
||||||
@ -273,34 +232,15 @@ class GptOssExperts(nn.Module):
|
|||||||
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
||||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||||
|
|
||||||
is_mxfp4 = hasattr(self, "gate_up_proj_blocks")
|
|
||||||
|
|
||||||
for ei in expert_hit:
|
for ei in expert_hit:
|
||||||
expert_idx = int(ei.item())
|
expert_idx = int(ei.item())
|
||||||
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
||||||
current = hidden_states[token_idx]
|
current = hidden_states[token_idx]
|
||||||
|
|
||||||
if is_mxfp4:
|
gate_up = self.gate_up_proj.expert_linear(current, expert_idx)
|
||||||
gate_up_w = self._dequant_one_expert(
|
|
||||||
self.gate_up_proj_blocks[expert_idx],
|
|
||||||
self.gate_up_proj_scales[expert_idx],
|
|
||||||
current.dtype,
|
|
||||||
)
|
|
||||||
down_w = self._dequant_one_expert(
|
|
||||||
self.down_proj_blocks[expert_idx],
|
|
||||||
self.down_proj_scales[expert_idx],
|
|
||||||
current.dtype,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
gate_up_w = comfy.ops.cast_to_input(self.gate_up_proj[expert_idx], current, copy=False)
|
|
||||||
down_w = comfy.ops.cast_to_input(self.down_proj[expert_idx], current, copy=False)
|
|
||||||
|
|
||||||
gate_up_b = comfy.ops.cast_to_input(self.gate_up_proj_bias[expert_idx], current, copy=False)
|
|
||||||
down_b = comfy.ops.cast_to_input(self.down_proj_bias[expert_idx], current, copy=False)
|
|
||||||
|
|
||||||
gate_up = current @ gate_up_w + gate_up_b
|
|
||||||
gated = self._apply_gate(gate_up)
|
gated = self._apply_gate(gate_up)
|
||||||
expert_out = gated @ down_w + down_b
|
expert_out = self.down_proj.expert_linear(gated, expert_idx)
|
||||||
|
|
||||||
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
||||||
|
|
||||||
flat_idx = token_idx * top_k + top_k_pos
|
flat_idx = token_idx * top_k + top_k_pos
|
||||||
@ -310,10 +250,10 @@ class GptOssExperts(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GptOssMLP(nn.Module):
|
class GptOssMLP(nn.Module):
|
||||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
|
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
|
||||||
self.experts = GptOssExperts(config, device=device, dtype=dtype)
|
self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
B, S, H = hidden_states.shape
|
B, S, H = hidden_states.shape
|
||||||
@ -329,7 +269,7 @@ class GptOssDecoderLayer(nn.Module):
|
|||||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
|
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
|
||||||
self.mlp = GptOssMLP(config, device=device, dtype=dtype)
|
self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||||
self.layer_type = config.layer_types[layer_idx]
|
self.layer_type = config.layer_types[layer_idx]
|
||||||
@ -579,125 +519,24 @@ class LensTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# MXFP4 E2M1 LUT (1 sign + 2 exp + 1 mantissa).
|
|
||||||
_FP4_VALUES: Tuple[float, ...] = (
|
|
||||||
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
|
|
||||||
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_FP4_LUT_CACHE: Dict[Tuple[torch.dtype, str], torch.Tensor] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def _fp4_lut(dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
|
||||||
"""Cached per (dtype, device) FP4 lookup table — avoids per-call allocation."""
|
|
||||||
key = (dtype, str(device))
|
|
||||||
lut = _FP4_LUT_CACHE.get(key)
|
|
||||||
if lut is None:
|
|
||||||
lut = torch.tensor(_FP4_VALUES, dtype=dtype, device=device)
|
|
||||||
_FP4_LUT_CACHE[key] = lut
|
|
||||||
return lut
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_dequant_moe_tensor(
    blocks: torch.Tensor,
    scales: torch.Tensor,
    *,
    dtype: torch.dtype = torch.bfloat16,
    rows_per_chunk: int = 4096,
) -> torch.Tensor:
    """Eagerly dequantize a whole packed MXFP4 tensor to ``dtype``.

    ``blocks`` is ``[E, D, G, B]`` packed bytes and ``scales`` is ``[E, D, G]``
    biased exponents (bias 127). The output is allocated up front in its final
    transposed layout ``[E, G*B*2, D]`` and filled per expert in chunks of at
    most ``rows_per_chunk`` rows of ``D``, so only one chunk-sized temporary
    is alive at a time.

    Raises:
        AssertionError: ``blocks.shape[:-1]`` differs from ``scales.shape``.
        ValueError: ``blocks`` does not have exactly two leading dims.
    """
    packed = blocks.to(torch.uint8)
    exponents = scales.to(torch.int32) - 127

    # Dtype conversion keeps shapes, so checking the arguments is equivalent.
    assert blocks.shape[:-1] == scales.shape, (
        f"{blocks.shape[:-1]=} does not match {scales.shape=}"
    )
    *lead_dims, groups, block_bytes = packed.shape
    if len(lead_dims) != 2:
        raise ValueError(f"expected 2-D prefix (E, 2*I); got {lead_dims}")

    num_experts, rows = lead_dims
    row_width = groups * block_bytes * 2  # this is H after dequant

    flat_rows = num_experts * rows * groups
    packed = packed.reshape(flat_rows, block_bytes)
    exponents = exponents.reshape(flat_rows, 1)

    table = _fp4_lut(dtype, packed.device)
    result = torch.empty(num_experts, row_width, rows, dtype=dtype, device=packed.device)

    groups_per_expert = rows * groups
    for expert in range(num_experts):
        expert_base = expert * groups_per_expert
        for start in range(0, rows, rows_per_chunk):
            stop = min(start + rows_per_chunk, rows)
            lo = expert_base + start * groups
            hi = expert_base + stop * groups
            chunk_packed = packed[lo:hi]
            chunk_exp = exponents[lo:hi]
            decoded = torch.empty((stop - start) * groups, block_bytes * 2, dtype=dtype, device=packed.device)
            # Low nibble fills even output columns, high nibble the odd ones.
            decoded[:, 0::2] = table[(chunk_packed & 0x0F).to(torch.long)]
            decoded[:, 1::2] = table[(chunk_packed >> 4).to(torch.long)]
            torch.ldexp(decoded, chunk_exp, out=decoded)
            result[expert, :, start:stop] = decoded.view(stop - start, row_width).transpose(0, 1)
            # Drop the temporaries before the next chunk is allocated.
            del chunk_packed, chunk_exp, decoded
    return result
|
|
||||||
|
|
||||||
|
|
||||||
def _dequant_mxfp4_state_dict(sd: Dict[str, torch.Tensor], target_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
|
||||||
"""Eager-dequant every ``*_blocks``/``*_scales`` pair in ``sd`` in place."""
|
|
||||||
pairs: List[Tuple[str, str, str]] = []
|
|
||||||
for k in list(sd.keys()):
|
|
||||||
if k.endswith("_blocks"):
|
|
||||||
stem = k[: -len("_blocks")]
|
|
||||||
sk = stem + "_scales"
|
|
||||||
if sk in sd:
|
|
||||||
pairs.append((stem, k, sk))
|
|
||||||
|
|
||||||
if not pairs:
|
|
||||||
return sd
|
|
||||||
|
|
||||||
logging.info("Lens: dequantizing %d MXFP4 expert tensors -> %s", len(pairs), target_dtype)
|
|
||||||
for stem, bk, sk in pairs:
|
|
||||||
blocks = sd.pop(bk)
|
|
||||||
scales = sd.pop(sk)
|
|
||||||
sd[stem] = _safe_dequant_moe_tensor(blocks, scales, dtype=target_dtype)
|
|
||||||
del blocks, scales
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
return sd
|
|
||||||
|
|
||||||
|
|
||||||
class LensGptOssClipModel(nn.Module):
|
class LensGptOssClipModel(nn.Module):
|
||||||
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
|
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
|
||||||
|
|
||||||
def __init__(self, device="cpu", dtype=None, model_options=None, **_):
|
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
model_options = dict(model_options or {})
|
model_options = dict(model_options or {})
|
||||||
|
|
||||||
operations = model_options.get("custom_operations")
|
operations = model_options.get("custom_operations")
|
||||||
quant_config = model_options.get("quantization_metadata")
|
|
||||||
if operations is None:
|
if operations is None:
|
||||||
if quant_config is not None:
|
quant_config = model_options.get("quantization_metadata") or {}
|
||||||
operations = comfy.ops.mixed_precision_ops(
|
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||||
quant_config, dtype, full_precision_mm=True
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
operations = comfy.ops.manual_cast
|
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
|
|
||||||
cfg_overrides = model_options.get("gpt_oss_config", {})
|
cfg_overrides = model_options.get("gpt_oss_config", {})
|
||||||
self.config = GptOss20BConfig(**cfg_overrides)
|
self.config = GptOss20BConfig(**cfg_overrides)
|
||||||
self.selected_layers = tuple(
|
self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
|
||||||
model_options.get("selected_layers", LENS_SELECTED_LAYERS)
|
|
||||||
)
|
|
||||||
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
|
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
|
||||||
|
|
||||||
# mxfp4_runtime=True keeps experts packed and dequants per hit at forward.
|
|
||||||
self.mxfp4_runtime = bool(model_options.get("mxfp4_runtime", False))
|
|
||||||
|
|
||||||
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
|
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
|
||||||
self.num_layers = self.config.num_hidden_layers
|
self.num_layers = self.config.num_hidden_layers
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
@ -755,17 +594,6 @@ class LensGptOssClipModel(nn.Module):
|
|||||||
return flat, None, extra
|
return flat, None, extra
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
if any(k.startswith("model.") for k in sd):
|
|
||||||
sd = {(k[len("model."):] if k.startswith("model.") else k): v for k, v in sd.items()}
|
|
||||||
sd.pop("lm_head.weight", None)
|
|
||||||
|
|
||||||
if self.mxfp4_runtime:
|
|
||||||
device = next(self.transformer.parameters()).device
|
|
||||||
for layer in self.transformer.layers:
|
|
||||||
layer.mlp.experts.switch_to_mxfp4(device=device)
|
|
||||||
else:
|
|
||||||
sd = _dequant_mxfp4_state_dict(sd, self.dtype)
|
|
||||||
|
|
||||||
return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
||||||
|
|
||||||
|
|
||||||
@ -774,7 +602,7 @@ class LensTEModel(sd1_clip.SD1ClipModel):
|
|||||||
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
|
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
|
||||||
|
|
||||||
|
|
||||||
def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=False):
|
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
class LensTEModel_(LensTEModel):
|
class LensTEModel_(LensTEModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||||
mo = dict(model_options or {})
|
mo = dict(model_options or {})
|
||||||
@ -782,8 +610,6 @@ def lens_te(dtype_llama=None, llama_quantization_metadata=None, mxfp4_runtime=Fa
|
|||||||
mo["quantization_metadata"] = llama_quantization_metadata
|
mo["quantization_metadata"] = llama_quantization_metadata
|
||||||
if dtype_llama is not None:
|
if dtype_llama is not None:
|
||||||
dtype = dtype_llama
|
dtype = dtype_llama
|
||||||
if mxfp4_runtime:
|
|
||||||
mo["mxfp4_runtime"] = True
|
|
||||||
super().__init__(device=device, dtype=dtype, model_options=mo)
|
super().__init__(device=device, dtype=dtype, model_options=mo)
|
||||||
|
|
||||||
return LensTEModel_
|
return LensTEModel_
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user