mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 16:32:34 +08:00
97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
import torch
|
|
import logging
|
|
from typing import Dict, Optional
|
|
|
|
import comfy.ops
|
|
from modelopt.torch.quantization.nn import QuantModuleRegistry, TensorQuantizer
|
|
|
|
|
|
# FP8 E4M3 configuration.
# Weight and input quantizers share identical settings, so both wildcard
# patterns are generated from one template; each pattern still receives its
# own dict object, so mutating one entry never affects the other.
# NOTE(review): num_bits=(4, 3) presumably encodes (exponent, mantissa) bits
# and axis=None presumably selects per-tensor scaling — confirm against the
# ModelOpt quantization config documentation.
FP8_CFG = {
    "quant_cfg": {
        **{
            pattern: {"num_bits": (4, 3), "axis": None}
            for pattern in ("*weight_quantizer", "*input_quantizer")
        },
        "default": {"enable": False},
    },
    "algorithm": "max",
}
|
|
|
|
|
|
def register_comfy_ops():
    """Make ModelOptimizer treat ComfyUI's Linear like the ``torch.nn`` one.

    Looks up the quantized handler that ``QuantModuleRegistry`` already holds
    for the ``torch.nn`` class sharing the ComfyUI op's ``__name__`` and
    registers that handler for ``comfy.ops.disable_weight_init.Linear`` under
    the key ``"comfy.<name>"``. Does nothing (beyond a debug log) when the
    ComfyUI class is already present in the registry.
    """
    comfy_linear = comfy.ops.disable_weight_init.Linear
    linear_name = comfy_linear.__name__

    # Guard clause: a second call must not register the class again.
    if comfy_linear in QuantModuleRegistry:
        logging.debug("ComfyUI Linear already registered with ModelOptimizer")
        return

    # ``register`` returns a decorator; apply it to the handler that is
    # already registered for the torch.nn class of the same name.
    register_as_comfy = QuantModuleRegistry.register(
        {comfy_linear: f"comfy.{linear_name}"}
    )
    torch_counterpart = getattr(torch.nn, linear_name)
    quant_handler = QuantModuleRegistry._registry[torch_counterpart]
    register_as_comfy(quant_handler)

    logging.info("Registered ComfyUI Linear with ModelOptimizer")
|
|
|
|
def log_quant_summary(model: torch.nn.Module, log_level=logging.INFO):
    """Log each TensorQuantizer found in ``model`` and then the total count.

    Args:
        model: Module whose ``named_modules()`` are scanned.
        log_level: Level passed to ``logging.log`` for every emitted line.
    """
    # Lazy filter so each quantizer is logged as soon as it is encountered.
    quantizers = (
        (path, submodule)
        for path, submodule in model.named_modules()
        if isinstance(submodule, TensorQuantizer)
    )

    found = 0
    for found, (path, quantizer) in enumerate(quantizers, start=1):
        logging.log(log_level, f"{path:80} {quantizer}")

    logging.log(log_level, f"{found} TensorQuantizers found in model")
|
|
|
|
def extract_amax_values(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    """Collect the ``_amax`` value of every enabled TensorQuantizer in ``model``.

    Quantizers that are disabled, or whose ``_amax`` is missing or ``None``,
    are skipped. Non-tensor amax values are converted to float32 tensors.

    Args:
        model: Module whose ``named_modules()`` are scanned.

    Returns:
        Mapping of module name to a cloned CPU copy of its amax tensor. The
        tensors may hold a single value or several (the companion
        ``save_amax_dict`` serializes both forms).
    """
    amax_dict = {}

    for name, module in model.named_modules():
        if not isinstance(module, TensorQuantizer) or not module.is_enabled:
            continue

        amax = getattr(module, "_amax", None)
        if amax is None:
            continue
        if not isinstance(amax, torch.Tensor):
            amax = torch.tensor(amax, dtype=torch.float32)

        amax_cpu = amax.clone().cpu()
        amax_dict[name] = amax_cpu

        # Tensor.item() raises for anything but a single-element tensor, and
        # the f-string is evaluated even when debug logging is disabled, so
        # only format the scalar when there is exactly one value. Reading from
        # the CPU copy also avoids a second device-to-host transfer.
        if amax_cpu.numel() == 1:
            logging.debug(f"Extracted amax from {name}: {amax_cpu.item():.6f}")
        else:
            logging.debug(
                f"Extracted amax from {name}: tensor of shape {tuple(amax_cpu.shape)}"
            )

    logging.info(f"Extracted amax values from {len(amax_dict)} quantizers")
    return amax_dict
|
|
|
|
|
|
def save_amax_dict(amax_dict: Dict[str, torch.Tensor], output_path: str, metadata: Optional[Dict] = None):
    """Write amax values plus run metadata to ``output_path`` as formatted JSON.

    The file contains two top-level keys:

    * ``"metadata"`` - an ISO-format ``timestamp``, ``num_layers`` (number of
      entries written) and every key from ``metadata``; caller-supplied keys
      override the generated ones on collision.
    * ``"amax_values"`` - name -> float for single-element tensors and
      non-tensor values, or name -> (nested) list for multi-element tensors.

    Args:
        amax_dict: Mapping of quantizer name to amax tensor (or plain number).
        output_path: Destination file path; an existing file is overwritten.
        metadata: Optional extra entries merged into the ``"metadata"`` section.
    """
    import json
    from datetime import datetime

    logging.info(f"Saving {len(amax_dict)} amax values to {output_path}")

    # Convert tensors to plain Python values for JSON serialization.
    amax_values = {}
    for key, value in amax_dict.items():
        if not isinstance(value, torch.Tensor):
            amax_values[key] = float(value)
        elif value.numel() == 1:
            amax_values[key] = float(value.item())
        else:
            # Tensor.tolist() converts directly, so dtypes NumPy cannot
            # represent (e.g. bfloat16) and tensors that require grad are
            # handled instead of raising in a .numpy() round-trip.
            amax_values[key] = value.detach().cpu().tolist()

    # Build output with metadata; caller-supplied keys take precedence.
    output_dict = {
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "num_layers": len(amax_values),
            **(metadata or {}),
        },
        "amax_values": amax_values,
    }

    # Save as formatted JSON for easy inspection.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output_dict, f, indent=2, sort_keys=True)

    logging.info(f"✓ Amax values saved to {output_path}")
|