Gavin Li 2025-12-14 11:02:46 +01:00 committed by GitHub
commit 8b34f1c661
6 changed files with 4693 additions and 32 deletions


@@ -124,6 +124,10 @@ We define 4 possible scaling parameters that should cover most recipes in the ne
 | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
 |--------|---------------|--------------|----------------|-----------------|-------------|
 | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
+| int8_blockwise | int8 | float32 (per-block) | - | - | - |
+
+For int8_blockwise with block_size=128 and weight shape (N, K):
+- weight_scale shape: (N//128, K//128)
 
 You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
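The added row and shape note document the per-block scale layout. As a quick cross-check, the shape arithmetic can be reproduced with a short PyTorch sketch of symmetric absmax block-wise quantization; `quantize_int8_blockwise` is a hypothetical helper written for illustration (it assumes both dimensions divide evenly by the block size), not the kernel in `comfy/int8_kernels.py`:

```python
import torch

def quantize_int8_blockwise(weight: torch.Tensor, block_size: int = 128):
    # Symmetric absmax quantization over (block_size x block_size) tiles.
    # Assumes N % block_size == 0 and K % block_size == 0.
    N, K = weight.shape
    blocks = weight.reshape(N // block_size, block_size, K // block_size, block_size)
    absmax = blocks.abs().amax(dim=(1, 3))             # per-block max magnitude
    scale = (absmax / 127.0).clamp(min=1e-12).float()  # float32 (per-block)
    q = (blocks / scale[:, None, :, None]).round().clamp(-127, 127).to(torch.int8)
    return q.reshape(N, K), scale                      # scale: (N//128, K//128)

q, s = quantize_int8_blockwise(torch.randn(256, 384))
assert s.shape == (2, 3)  # matches the documented weight_scale shape
```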
@@ -131,7 +135,9 @@ You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
 The metadata stored alongside the checkpoint contains:
 - **format_version**: String to define a version of the standard
-- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
+- **layers**: A dictionary mapping layer names to their quantization configuration. Each layer's config is a dictionary with:
+  - **format**: Quantization format string that maps to the definitions found in `QUANT_ALGOS`
+  - **group_size** (optional): Block size for block-wise quantization schemes (e.g., int8_blockwise)
 
 Example:
 ```json
@@ -139,9 +145,9 @@ Example:
   "_quantization_metadata": {
     "format_version": "1.0",
     "layers": {
-      "model.layers.0.mlp.up_proj": "float8_e4m3fn",
-      "model.layers.0.mlp.down_proj": "float8_e4m3fn",
-      "model.layers.1.mlp.up_proj": "float8_e4m3fn"
+      "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
+      "model.layers.0.mlp.down_proj": {"format": "int8_blockwise", "group_size": 128},
+      "model.layers.1.mlp.up_proj": {"format": "int8_blockwise", "group_size": 256}
     }
   }
 }
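On disk this block sits in the safetensors header. A minimal sketch of reading it back, assuming `_quantization_metadata` is stored as a JSON string in the file-level metadata (`load_quant_metadata` is a hypothetical helper, not an API from this repo):

```python
import json
from safetensors import safe_open

def load_quant_metadata(path: str) -> dict:
    # File-level metadata in safetensors is a str -> str mapping, so the
    # quantization block has to be JSON-decoded after lookup.
    with safe_open(path, framework="pt") as f:
        meta = f.metadata() or {}
    raw = meta.get("_quantization_metadata")
    return json.loads(raw) if raw else {}

# Per-layer lookup; .get() returns None for layers without an entry:
# cfg = load_quant_metadata("model.safetensors")["layers"].get(layer_name)
```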


@@ -54,6 +54,8 @@ def stochastic_rounding(value, dtype, seed=0):
         return value.to(dtype=torch.float16)
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
+    if dtype == torch.int8:
+        return value.to(dtype=torch.int8)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
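The new int8 branch casts directly, like the existing float16/bfloat16 branches; only the fp8 branch below it seeds a generator for stochastic rounding. For context, stochastic rounding rounds up or down with probability equal to the fractional part, which makes the result unbiased in expectation. A standalone sketch of the idea (illustrative only, not the repo's fp8 code path):

```python
import torch

def stochastic_round_sketch(x: torch.Tensor, seed: int = 0) -> torch.Tensor:
    # Round down, then bump up with probability equal to the fractional
    # part, so E[result] == x (plain .round() biases toward the nearest value).
    g = torch.Generator(device=x.device).manual_seed(seed)
    floor = x.floor()
    frac = x - floor
    bump = torch.rand(x.shape, generator=g, device=x.device) < frac
    return floor + bump.to(x.dtype)
```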

comfy/int8_kernels.py (new file, 1194 lines)

File diff suppressed because it is too large.


@@ -552,12 +552,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         scale = state_dict.pop(weight_scale_key, None)
         if scale is not None:
             scale = scale.to(device)
+
+        # Check for per-layer group_size override, otherwise use default from QUANT_ALGOS
+        layer_config = MixedPrecisionOps._layer_quant_config[layer_name]
+        group_size = layer_config.get("group_size", qconfig.get("group_size", None))
+
         layout_params = {
             'scale': scale,
             'orig_dtype': MixedPrecisionOps._compute_dtype,
-            'block_size': qconfig.get("group_size", None),
+            'block_size': group_size,
         }
+        if qconfig.get("asymmetric_layout", False):
+            layout_params['is_weight'] = True
+
         if scale is not None:
             manually_loaded_keys.append(weight_scale_key)
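The resolved `group_size` feeds `layout_params['block_size']`, which together with `scale` and `orig_dtype` is all a consumer needs to undo the quantization. A minimal sketch of the inverse for the block-wise int8 case, assuming the (N//block_size, K//block_size) scale layout documented above (`dequantize_int8_blockwise` is hypothetical, not the repo's kernel):

```python
import torch

def dequantize_int8_blockwise(q: torch.Tensor, scale: torch.Tensor,
                              block_size: int, orig_dtype=torch.bfloat16):
    # Expand the per-block scale grid to element granularity, then rescale.
    # Assumes q is (N, K) with both dims exact multiples of block_size.
    s = scale.repeat_interleave(block_size, dim=0).repeat_interleave(block_size, dim=1)
    return (q.to(torch.float32) * s).to(orig_dtype)
```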

File diff suppressed because it is too large.

File diff suppressed because it is too large.