Remove CK reference and ensure correct compute dtype

lspindler 2025-10-27 07:55:44 +01:00
parent 5e9f335753
commit efb35035f3
4 changed files with 11 additions and 36 deletions

View File

@@ -7,8 +7,7 @@ import logging
 import torch
-def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
-    # 1. Check for per-layer config in metadata
+def detect_layer_quantization(metadata):
     quant_key = "_quantization_metadata"
     if metadata is not None and quant_key in metadata:
         quant_metadata = metadata.pop(quant_key)
@@ -18,13 +17,6 @@ def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_mode
             return quant_metadata["layers"]
         else:
             raise ValueError(f"Invalid quantization metadata format")
-    # 2. Check for legacy scaled_fp8 marker
-    scaled_fp8_key = f"{prefix}scaled_fp8"
-    if scaled_fp8_key in state_dict:
-        logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
-        return None
     return None
@@ -724,7 +716,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
         model_config.optimizations["fp8"] = True
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix)
+    layer_quant_config = detect_layer_quantization(metadata)
    if layer_quant_config:
         model_config.layer_quant_config = layer_quant_config
         logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")

View File

@@ -561,11 +561,9 @@ class MixedPrecisionOps(disable_weight_init):
 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
-    # If model_config.layer_quant_config exists, use new MixedPrecisionOps.
     if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
         MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
-        # MixedPrecisionOps._compute_dtype = compute_dtype # TODO
-        MixedPrecisionOps._compute_dtype = torch.bfloat16
+        MixedPrecisionOps._compute_dtype = compute_dtype
         logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
         return MixedPrecisionOps
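Note: the hard-coded bfloat16 above is the compute-dtype bug the commit title refers to; the mixed precision path now inherits whatever compute dtype the caller resolved. A simplified sketch of the effect (names shortened, not the real operations module):

import torch

class _MixedPrecisionOpsSketch:
    _layer_quant_config = None
    _compute_dtype = None

def pick_operations_sketch(compute_dtype, layer_quant_config):
    if layer_quant_config:
        _MixedPrecisionOpsSketch._layer_quant_config = layer_quant_config
        # Previously pinned to torch.bfloat16; now follows the argument,
        # e.g. torch.float16 on devices without bf16 support.
        _MixedPrecisionOpsSketch._compute_dtype = compute_dtype
        return _MixedPrecisionOpsSketch
    return None

ops = pick_operations_sketch(torch.float16, {"blocks.0.attn.qkv": {}})
assert ops._compute_dtype is torch.float16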

View File

@@ -75,21 +75,6 @@ def _copy_layout_params(params):
     return new_params
-try:
-    import comfy_kitchen as ck
-    ck.disable_backend("cutile")
-    _CK_AVAILABLE = True
-    logging.info("comfy_kitchen available for optimized quantization kernels")
-except ImportError:
-    ck = None
-    _CK_AVAILABLE = False
-    logging.info("comfy_kitchen not available - using PyTorch fallbacks")
-except Exception as e:
-    ck = None
-    _CK_AVAILABLE = False
-    logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks")
 class QuantizedLayout:
     """
     Base class for quantization layouts.
@@ -372,13 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
-        if _CK_AVAILABLE and tensor.device.type == "cuda":
-            qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
-        else:
-            lp_amax = torch.finfo(fp8_dtype).max
-            tensor_scaled = tensor.float() / scale
-            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-            qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+        lp_amax = torch.finfo(fp8_dtype).max
+        tensor_scaled = tensor.float() / scale
+        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
         layout_params = {
             'scale': scale,
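Note: with the comfy_kitchen kernel call removed, the clamp-and-cast path above is the only quantization route. A standalone sketch of that per-tensor FP8 scheme (helper name is ours; requires a PyTorch build with float8 dtypes):

import torch

def fp8_quantize_per_tensor(tensor, scale, fp8_dtype=torch.float8_e4m3fn):
    scale = torch.as_tensor(scale, dtype=torch.float32, device=tensor.device)
    lp_amax = torch.finfo(fp8_dtype).max                 # largest finite fp8 value
    tensor_scaled = tensor.float() / scale               # bring values into fp8 range
    torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
    return tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)

w = torch.randn(64, 64)
scale = w.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
q = fp8_quantize_per_tensor(w, scale)
w_approx = q.to(torch.float32) * scale                   # dequantize for a quick sanity check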

View File

@@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     else:
         unet_dtype = dtype
-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    if hasattr(model_config, "layer_quant_config"):
+        manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
     if model_options.get("fp8_optimizations", False):
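Note: when layer_quant_config is present the checkpoint mixes several weight precisions, so a single stored weight dtype is no longer meaningful; passing None lets the manual-cast decision fall back to the config's supported inference dtypes. A simplified illustration of that decision (not the real model_management.unet_manual_cast):

import torch

def manual_cast_sketch(weight_dtype, supported_inference_dtypes, device_supports):
    # weight_dtype=None means "mixed precision checkpoint, don't trust a single dtype".
    if weight_dtype is not None and device_supports(weight_dtype):
        return None  # weights can be used as-is, no manual casting needed
    for dt in supported_inference_dtypes:
        if device_supports(dt):
            return dt
    return torch.float32

# On a device without bf16 support, a mixed precision model would cast to fp16.
print(manual_cast_sketch(None, [torch.bfloat16, torch.float16], lambda dt: dt is torch.float16))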