Merge 5ba2d28b7f into 5ac3b26a7d
Commit 8b34f1c661
@@ -124,6 +124,10 @@ We define 4 possible scaling parameters that should cover most recipes in the ne
 | Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
 |--------|---------------|--------------|----------------|-----------------|-------------|
 | float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
+| int8_blockwise | int8 | float32 (per-block) | - | - | - |
+
+For int8_blockwise with block_size=128 and weight shape (N, K):
+- weight_scale shape: (N//128, K//128)
 
 You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
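To make the new table row and scale shape concrete, here is a minimal sketch of symmetric absmax block-wise quantization. It is illustrative only (the actual kernels live in `comfy/int8_kernels.py` and may use a different recipe), but it produces a `weight_scale` of exactly the shape described above:

```python
import torch

def quantize_int8_blockwise(weight: torch.Tensor, block_size: int = 128):
    # Illustrative sketch, not the comfy/int8_kernels.py implementation.
    # Assumes N and K are both divisible by block_size.
    N, K = weight.shape
    tiles = weight.reshape(N // block_size, block_size, K // block_size, block_size)
    # One float32 absmax scale per (block_size x block_size) tile,
    # giving weight_scale shape (N//block_size, K//block_size).
    scale = (tiles.abs().amax(dim=(1, 3)).float() / 127.0).clamp(min=1e-12)
    q = (tiles / scale[:, None, :, None]).round().clamp(-127, 127).to(torch.int8)
    return q.reshape(N, K), scale

def dequantize_int8_blockwise(q: torch.Tensor, scale: torch.Tensor, block_size: int = 128):
    N, K = q.shape
    tiles = q.reshape(N // block_size, block_size, K // block_size, block_size).float()
    return (tiles * scale[:, None, :, None]).reshape(N, K)
```

Round-tripping a weight through these two functions reproduces it up to the per-tile quantization error.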
@@ -131,7 +135,9 @@ You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
 
 The metadata stored alongside the checkpoint contains:
 - **format_version**: String to define a version of the standard
-- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
+- **layers**: A dictionary mapping layer names to their quantization configuration. Each layer's config is a dictionary with:
+  - **format**: Quantization format string that maps to the definitions found in `QUANT_ALGOS`
+  - **group_size** (optional): Block size for block-wise quantization schemes (e.g., int8_blockwise)
 
 Example:
 ```json
@@ -139,9 +145,9 @@ Example:
   "_quantization_metadata": {
     "format_version": "1.0",
     "layers": {
-      "model.layers.0.mlp.up_proj": "float8_e4m3fn",
-      "model.layers.0.mlp.down_proj": "float8_e4m3fn",
-      "model.layers.1.mlp.up_proj": "float8_e4m3fn"
+      "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
+      "model.layers.0.mlp.down_proj": {"format": "int8_blockwise", "group_size": 128},
+      "model.layers.1.mlp.up_proj": {"format": "int8_blockwise", "group_size": 256}
     }
   }
 }
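A loader consuming this metadata only needs the `format` key, plus `group_size` when the format is block-wise. A sketch, assuming the metadata travels as a JSON string under the `_quantization_metadata` key of the safetensors header (the file name is hypothetical; adjust if the checkpoint stores it elsewhere):

```python
import json
from safetensors import safe_open

def read_quant_metadata(path: str) -> dict:
    # Assumption: _quantization_metadata is a JSON string in the
    # safetensors header metadata.
    with safe_open(path, framework="pt") as f:
        header_meta = f.metadata() or {}
    return json.loads(header_meta.get("_quantization_metadata", "{}"))

quant = read_quant_metadata("model.safetensors")  # hypothetical path
for layer, cfg in quant.get("layers", {}).items():
    fmt = cfg["format"]                 # maps into QUANT_ALGOS
    group_size = cfg.get("group_size")  # optional, block-wise formats only
    print(layer, fmt, group_size)
```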
@@ -54,6 +54,8 @@ def stochastic_rounding(value, dtype, seed=0):
         return value.to(dtype=torch.float16)
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
+    if dtype == torch.int8:
+        return value.to(dtype=torch.int8)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
         generator = torch.Generator(device=value.device)
         generator.manual_seed(seed)
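Note that the new `torch.int8` branch is a plain deterministic cast; the seeded-generator path below it applies stochastic rounding only for the float8 dtypes. For readers unfamiliar with the technique, here is a self-contained sketch of stochastic rounding to integers (illustrative, not the repository's code):

```python
import torch

def stochastic_round_to_int8(value: torch.Tensor, seed: int = 0) -> torch.Tensor:
    # floor(x) + Bernoulli(frac(x)): round up with probability equal to the
    # fractional part, so the rounded value equals x in expectation.
    generator = torch.Generator(device=value.device)
    generator.manual_seed(seed)
    floor = value.floor()
    frac = value - floor
    noise = torch.rand(value.shape, generator=generator, device=value.device)
    rounded = floor + (noise < frac).to(value.dtype)
    return rounded.clamp(-128, 127).to(torch.int8)
```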
comfy/int8_kernels.py: new file, 1194 lines (diff suppressed because it is too large)
comfy/ops.py: 10 lines changed
@@ -552,12 +552,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         scale = state_dict.pop(weight_scale_key, None)
         if scale is not None:
             scale = scale.to(device)
 
+        # Check for per-layer group_size override, otherwise use default from QUANT_ALGOS
+        layer_config = MixedPrecisionOps._layer_quant_config[layer_name]
+        group_size = layer_config.get("group_size", qconfig.get("group_size", None))
+
         layout_params = {
             'scale': scale,
             'orig_dtype': MixedPrecisionOps._compute_dtype,
-            'block_size': qconfig.get("group_size", None),
+            'block_size': group_size,
         }
 
+        if qconfig.get("asymmetric_layout", False):
+            layout_params['is_weight'] = True
+
         if scale is not None:
             manually_loaded_keys.append(weight_scale_key)
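The lookup order in the added lines means a per-layer `group_size` from the checkpoint metadata wins over the format default from `QUANT_ALGOS`. A toy illustration of the precedence (values are made up):

```python
# Per-layer entry from _quantization_metadata:
layer_config = {"format": "int8_blockwise", "group_size": 256}
# Format default from QUANT_ALGOS (hypothetical value):
qconfig = {"group_size": 128}

group_size = layer_config.get("group_size", qconfig.get("group_size", None))
assert group_size == 256  # the per-layer override takes precedence
```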