Mirror of https://github.com/comfyanonymous/ComfyUI.git
Remove CK reference and ensure correct compute dtype
commit efb35035f3
parent 5e9f335753
@@ -7,8 +7,7 @@ import logging

 import torch


-def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
-    # 1. Check for per-layer config in metadata
+def detect_layer_quantization(metadata):
     quant_key = "_quantization_metadata"
     if metadata is not None and quant_key in metadata:
         quant_metadata = metadata.pop(quant_key)
@@ -18,13 +17,6 @@ def detect_layer_quantization(state_dict, metadata, prefix="model.diffusion_model."):
             return quant_metadata["layers"]
         else:
             raise ValueError(f"Invalid quantization metadata format")
-
-    # 2. Check for legacy scaled_fp8 marker
-    scaled_fp8_key = f"{prefix}scaled_fp8"
-    if scaled_fp8_key in state_dict:
-        logging.debug("Detected legacy scaled_fp8 format, using legacy code path")
-        return None
-
     return None
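For context, a minimal standalone sketch of the metadata-only detection path this hunk leaves behind. The helper name, the JSON payload layout (the "layers" mapping and its per-layer fields), and the example layer key are assumptions for illustration; only the "_quantization_metadata" key comes from the diff.

import json

# Standalone sketch mirroring the new metadata-only detection path.
def detect_layer_quantization_sketch(metadata):
    quant_key = "_quantization_metadata"
    if metadata is not None and quant_key in metadata:
        quant_metadata = metadata.pop(quant_key)
        if isinstance(quant_metadata, str):  # safetensors metadata values are strings
            quant_metadata = json.loads(quant_metadata)
        if "layers" in quant_metadata:
            return quant_metadata["layers"]
        raise ValueError("Invalid quantization metadata format")
    return None  # no per-layer config; caller falls back to uniform dtype handling

# Hypothetical checkpoint metadata with one quantized layer.
metadata = {"_quantization_metadata": json.dumps(
    {"layers": {"double_blocks.0.img_mod.lin.weight": {"dtype": "float8_e4m3fn"}}})}
print(detect_layer_quantization_sketch(metadata))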
@@ -724,7 +716,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False
         model_config.optimizations["fp8"] = True

     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(state_dict, metadata, unet_key_prefix)
+    layer_quant_config = detect_layer_quantization(metadata)
     if layer_quant_config:
         model_config.layer_quant_config = layer_quant_config
         logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
@@ -561,11 +561,9 @@ class MixedPrecisionOps(disable_weight_init):


 def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
     # If model_config.layer_quant_config exists, use new MixedPrecisionOps.
     if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
         MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
-        # MixedPrecisionOps._compute_dtype = compute_dtype # TODO
-        MixedPrecisionOps._compute_dtype = torch.bfloat16
+        MixedPrecisionOps._compute_dtype = compute_dtype
         logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
         return MixedPrecisionOps
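A small sketch of why propagating the caller's compute_dtype matters here: with the hardcoded value, an fp16 or fp32 pipeline would silently compute quantized layers in bfloat16. The class and helper below are stand-ins for illustration, not the real ops implementation.

import torch

class MixedPrecisionOpsSketch:
    _layer_quant_config = None
    _compute_dtype = None

def pick_operations_sketch(compute_dtype, model_config):
    # Honor whatever dtype the model was configured to compute in.
    if model_config and getattr(model_config, "layer_quant_config", None):
        MixedPrecisionOpsSketch._layer_quant_config = model_config.layer_quant_config
        MixedPrecisionOpsSketch._compute_dtype = compute_dtype  # was: torch.bfloat16
        return MixedPrecisionOpsSketch
    return None

class _Cfg:
    layer_quant_config = {"some.layer.weight": {"dtype": "float8_e4m3fn"}}

ops = pick_operations_sketch(torch.float16, _Cfg())
assert ops._compute_dtype == torch.float16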
@@ -75,21 +75,6 @@ def _copy_layout_params(params):
     return new_params


-try:
-    import comfy_kitchen as ck
-    ck.disable_backend("cutile")
-    _CK_AVAILABLE = True
-    logging.info("comfy_kitchen available for optimized quantization kernels")
-except ImportError:
-    ck = None
-    _CK_AVAILABLE = False
-    logging.info("comfy_kitchen not available - using PyTorch fallbacks")
-except Exception as e:
-    ck = None
-    _CK_AVAILABLE = False
-    logging.warning(f"comfy_kitchen import failed: {e} - using PyTorch fallbacks")
-
-
 class QuantizedLayout:
     """
     Base class for quantization layouts.
@@ -372,13 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)

-        if _CK_AVAILABLE and tensor.device.type == "cuda":
-            qdata = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
-        else:
-            lp_amax = torch.finfo(fp8_dtype).max
-            tensor_scaled = tensor.float() / scale
-            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
-            qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+        lp_amax = torch.finfo(fp8_dtype).max
+        tensor_scaled = tensor.float() / scale
+        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        qdata = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)

         layout_params = {
             'scale': scale,
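The pure-PyTorch path kept by this hunk can be sanity-checked in isolation. The round-trip sketch below (quantize to fp8 with a per-tensor scale, dequantize back) uses the same scale/clamp/cast math as the fallback; the function names and the dequantize step are illustrative, not the project's API, and torch.float8_e4m3fn assumes PyTorch 2.1+.

import torch

def fp8_quantize_sketch(tensor, scale, fp8_dtype=torch.float8_e4m3fn):
    # Scale, clamp to the fp8 representable range, then cast.
    scale = torch.as_tensor(scale, device=tensor.device, dtype=torch.float32)
    lp_amax = torch.finfo(fp8_dtype).max
    tensor_scaled = tensor.float() / scale
    torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
    return tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format), scale

def fp8_dequantize_sketch(qdata, scale, out_dtype=torch.float32):
    # Inverse of the quantize step: upcast and undo the per-tensor scale.
    return qdata.to(torch.float32).mul_(scale).to(out_dtype)

w = torch.randn(4, 4)
q, s = fp8_quantize_sketch(w, scale=w.abs().max() / torch.finfo(torch.float8_e4m3fn).max)
print((fp8_dequantize_sketch(q, s) - w).abs().max())  # small quantization error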
@@ -1330,7 +1330,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     else:
         unet_dtype = dtype

-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
+    if hasattr(model_config, "layer_quant_config"):
+        manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
+    else:
+        manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
     model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
     if model_options.get("fp8_optimizations", False):
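Roughly, the new branch says: for mixed-precision checkpoints, do not treat the stored weight dtype as the compute dtype; pass None so a cast dtype is chosen from the supported inference dtypes. The toy below illustrates that decision with a stub standing in for model_management.unet_manual_cast; the stub's contract (return None when the weight dtype is directly usable, otherwise pick a cast dtype) is an assumption for illustration.

import torch

def unet_manual_cast_stub(weight_dtype, load_device, supported_dtypes):
    # Assumed contract: no manual cast needed if weights can be computed in as-is.
    if weight_dtype in supported_dtypes:
        return None
    return supported_dtypes[0]

class _Cfg:
    layer_quant_config = {"layer": {"dtype": "float8_e4m3fn"}}
    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

unet_dtype = torch.float8_e4m3fn
if hasattr(_Cfg, "layer_quant_config"):
    # Mixed precision: ignore the container dtype and force a manual cast choice.
    manual_cast_dtype = unet_manual_cast_stub(None, "cuda", _Cfg.supported_inference_dtypes)
else:
    manual_cast_dtype = unet_manual_cast_stub(unet_dtype, "cuda", _Cfg.supported_inference_dtypes)
print(manual_cast_dtype)  # -> torch.bfloat16 with this stub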