# mirror of https://github.com/comfyanonymous/ComfyUI.git
# synced 2026-04-20 15:32:32 +08:00
import argparse
|
|
import gc
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from safetensors import safe_open
|
|
from safetensors.torch import save_file
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
|
|
|
from comfy.quant_ops import QuantizedTensor
|
|
|
|
|
|
# Quantization layout applied to every converted tensor (FP8 E4M3, tensor-core).
LAYOUT = "TensorCoreFP8E4M3Layout"
# Serialized (as UTF-8 JSON bytes) into each quantized module's `.comfy_quant`
# marker tensor so the loader knows how the weights were quantized.
QUANT_CONFIG = {"format": "float8_e4m3fn"}
|
|
|
|
|
|
def is_packed_expert(key, tensor):
    """Return True for a packed (3-D) MoE expert projection tensor.

    Packed experts store all experts stacked along the leading axis, so the
    tensor must be rank 3 and the key must name one of the two expert
    projections.
    """
    if tensor.ndim != 3:
        return False
    packed_suffixes = (
        ".img_mlp.experts.gate_up_proj",
        ".img_mlp.experts.down_proj",
    )
    return key.endswith(packed_suffixes)
|
|
|
|
|
|
def is_split_expert(key):
    """Return True if `key` addresses an already-split (per-expert) MLP tensor."""
    split_markers = (
        ".img_mlp.experts.gate_up_projs.",
        ".img_mlp.experts.down_projs.",
    )
    return any(marker in key for marker in split_markers)
|
|
|
|
|
|
def is_linear_weight(key, tensor):
    """Return True for a `.weight` tensor of rank >= 2 (a quantizable matmul weight)."""
    if not key.endswith(".weight"):
        return False
    return tensor.ndim >= 2
|
|
|
|
|
|
def module_key_for_weight(key):
    """Strip the trailing ``.weight`` suffix, leaving the module's key."""
    suffix = ".weight"
    return key[: -len(suffix)]
|
|
|
|
|
|
def module_key_for_packed_expert(key):
    """Drop the final dotted component (the parameter name), keeping the module path."""
    parts = key.rsplit(".", 1)
    return parts[0]
|
|
|
|
|
|
def quant_config_tensor():
    """Encode the module-level QUANT_CONFIG as a uint8 tensor of UTF-8 JSON bytes."""
    payload = json.dumps(QUANT_CONFIG).encode("utf-8")
    return torch.tensor(list(payload), dtype=torch.uint8)
|
|
|
|
|
|
def quantize_tensor(key, tensor):
    # Quantize `tensor` using the module-level LAYOUT and return the resulting
    # state-dict entries keyed under `key`, ready to merge into the output dict.
    # NOTE(review): scale="recalculate" presumably derives the quantization
    # scale(s) from the tensor data — confirm against comfy.quant_ops.
    quantized = QuantizedTensor.from_float(tensor, LAYOUT, scale="recalculate")
    return quantized.state_dict(key)
|
|
|
|
|
|
def convert(input_path, output_path, overwrite=False):
    """Convert a checkpoint to scaled FP8, preserving packed MoE expert tensors.

    Walks every tensor in the source safetensors file:
      * packed expert tensors (3-D gate_up/down projections) are quantized and
        their modules tagged with a ``.comfy_quant`` marker after the loop;
      * ordinary linear ``.weight`` tensors (rank >= 2) are quantized and
        tagged immediately;
      * everything else is copied through unchanged.

    The result is written to ``<output>.tmp`` first and atomically renamed,
    so a crash never leaves a truncated file at `output_path`.

    Args:
        input_path: source safetensors checkpoint (BF16, packed experts).
        output_path: destination safetensors path.
        overwrite: when True, replace an existing output file.

    Raises:
        FileNotFoundError: `input_path` does not exist.
        FileExistsError: `output_path` exists and `overwrite` is False.
        RuntimeError: the input already contains split expert tensors, or
            contains no packed expert tensors at all.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)
    tmp_path = output_path.with_suffix(output_path.suffix + ".tmp")

    if not input_path.exists():
        raise FileNotFoundError(f"input checkpoint does not exist: {input_path}")
    if output_path.exists() and not overwrite:
        raise FileExistsError(f"output checkpoint already exists: {output_path}")
    if tmp_path.exists():
        tmp_path.unlink()  # stale leftover from an earlier aborted run

    start = time.time()
    out = {}
    packed_modules = set()
    stats = {
        "normal_quantized": 0,
        "packed_quantized": 0,
        "kept": 0,
        "split_expert_keys": 0,
    }

    with safe_open(input_path, framework="pt", device="cpu") as src:
        keys = list(src.keys())
        total = len(keys)
        print(f"source={input_path} keys={total}", flush=True)

        for index, key in enumerate(keys, 1):
            # Fail fast on the wrong source format. The original deferred this
            # check until after the loop, so a doomed checkpoint was fully
            # quantized before the error surfaced.
            if is_split_expert(key):
                stats["split_expert_keys"] += 1
                raise RuntimeError(
                    "input checkpoint already contains split Nucleus expert tensors; "
                    "use the BF16/BF16-packed source checkpoint instead"
                )

            tensor = src.get_tensor(key)

            if is_packed_expert(key, tensor):
                out.update(quantize_tensor(key, tensor))
                packed_modules.add(module_key_for_packed_expert(key))
                stats["packed_quantized"] += 1
            elif is_linear_weight(key, tensor):
                out.update(quantize_tensor(key, tensor))
                # Marker tensor telling the loader how this module is quantized.
                out[f"{module_key_for_weight(key)}.comfy_quant"] = quant_config_tensor()
                stats["normal_quantized"] += 1
            else:
                out[key] = tensor
                stats["kept"] += 1

            del tensor  # drop the source reference before loading the next tensor
            if index % 25 == 0 or index == total:
                elapsed = time.time() - start
                print(
                    "processed "
                    f"{index}/{total} "
                    f"normal_quantized={stats['normal_quantized']} "
                    f"packed_quantized={stats['packed_quantized']} "
                    f"kept={stats['kept']} "
                    f"elapsed={elapsed:.1f}s",
                    flush=True,
                )
                gc.collect()

    if stats["packed_quantized"] == 0:
        raise RuntimeError("no packed Nucleus expert tensors were found in the input checkpoint")

    # A packed expert module contributes multiple tensors, so its marker is
    # emitted once per module after the loop, not once per tensor.
    for module_key in packed_modules:
        out[f"{module_key}.comfy_quant"] = quant_config_tensor()

    output_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"saving {len(out)} tensors to {tmp_path}", flush=True)
    try:
        save_file(out, str(tmp_path), metadata={"format": "pt"})
        # Atomic rename: readers never observe a half-written output file.
        os.replace(tmp_path, output_path)
    except BaseException:
        # Don't leave a partial .tmp behind on failure (or Ctrl-C).
        if tmp_path.exists():
            tmp_path.unlink()
        raise

    stats["packed_modules"] = len(packed_modules)
    stats["seconds"] = round(time.time() - start, 1)
    stats["output_gib"] = round(output_path.stat().st_size / (1024 ** 3), 2)
    print(f"wrote {output_path}", flush=True)
    print(json.dumps(stats, indent=2), flush=True)
|
|
|
|
|
|
def main():
    """CLI entry point: parse command-line arguments and run the conversion."""
    description = (
        "Convert a BF16 Nucleus-Image diffusion checkpoint to scaled FP8 "
        "while preserving packed MoE expert tensors for grouped-mm."
    )
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("input", help="BF16 Nucleus-Image diffusion checkpoint")
    parser.add_argument("output", help="output FP8 safetensors path")
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="replace an existing output file",
    )
    opts = parser.parse_args()

    convert(opts.input, opts.output, overwrite=opts.overwrite)
|
|
|
|
|
|
# Script entry point: only run the CLI when executed directly, not on import.
if __name__ == "__main__":
    main()
|