mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-18 23:55:08 +08:00
Merge branch 'Comfy-Org:master' into master
This commit is contained in:
commit
08063d2638
@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
|||||||
output_block[i:i + slice_size].copy_(block)
|
output_block[i:i + slice_size].copy_(block)
|
||||||
|
|
||||||
return output_fp4, to_blocked(output_block, flatten=False)
|
return output_fp4, to_blocked(output_block, flatten=False)
|
||||||
|
|
||||||
|
|
||||||
|
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
    """Quantize a 2D tensor to MXFP8: float8_e4m3fn data + E8M0 block scales.

    Args:
        x: 2D float tensor to quantize.
        pad_32x: when True, zero-pad rows/cols up to multiples of 32 first.
        seed: seed forwarded to the stochastic rounding kernel.

    Returns:
        (output_fp8, block_scales): the float8_e4m3fn data tensor and the
        per-block E8M0 scales, rearranged via to_blocked() and viewed as
        torch.float8_e8m0fnu.
    """
    def roundup(x_val, multiple):
        # Round x_val up to the nearest multiple of `multiple`.
        return ((x_val + multiple - 1) // multiple) * multiple

    if pad_32x:
        rows, cols = x.shape
        padded_rows = roundup(rows, 32)
        padded_cols = roundup(cols, 32)
        if padded_rows != rows or padded_cols != cols:
            # Zero-pad on the right/bottom so both dims are 32-aligned.
            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))

    F8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn
    E8M0_BIAS = 127      # exponent bias of the E8M0 scale format
    BLOCK_SIZE = 32      # MX block size along the last dimension

    rows, cols = x.shape
    x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
    max_abs = torch.amax(torch.abs(x_blocked), dim=-1)

    # E8M0 block scales (power-of-2 exponents)
    # Clamp keeps log2 finite for blocks whose max is 0 (handled via zero_mask below).
    scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
    # ceil(log2(.)) picks the smallest power-of-2 scale that fits the block;
    # biased exponent is clamped to the valid E8M0 range [0, 254].
    exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
    block_scales_e8m0 = exp_biased.to(torch.uint8)

    zero_mask = (max_abs == 0)
    # Reinterpret the biased exponent as a float32 exponent field: value = 2^(e - 127).
    block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
    # All-zero blocks divide by 1.0 instead of the tiny clamped scale.
    block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)

    # Scale per-block then stochastic round
    data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
    output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)

    # Store a zero scale for all-zero blocks.
    block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
    return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||||
|
|||||||
@ -1712,6 +1712,19 @@ def supports_nvfp4_compute(device=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_mxfp8_compute(device=None):
    """Return True when `device` can run MXFP8 compute kernels."""
    # MXFP8 path is NVIDIA-only.
    if not is_nvidia():
        return False
    # Requires torch >= 2.10 for the MXFP8 kernel support used here.
    if torch_version_numeric < (2, 10):
        return False
    # Compute capability 10.x (SM 10) or newer only.
    return torch.cuda.get_device_properties(device).major >= 10
|
||||||
|
|
||||||
def extended_fp16_support():
|
def extended_fp16_support():
|
||||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||||
if torch_version_numeric < (2, 7):
|
if torch_version_numeric < (2, 7):
|
||||||
|
|||||||
19
comfy/ops.py
19
comfy/ops.py
@ -857,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
orig_shape=(self.out_features, self.in_features),
|
orig_shape=(self.out_features, self.in_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.quant_format == "mxfp8":
|
||||||
|
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||||
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||||
|
dtype=torch.uint8)
|
||||||
|
|
||||||
|
if block_scale is None:
|
||||||
|
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||||
|
|
||||||
|
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||||
|
|
||||||
|
params = layout_cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
|
|
||||||
elif self.quant_format == "nvfp4":
|
elif self.quant_format == "nvfp4":
|
||||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||||
@ -1006,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||||
|
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||||
|
|
||||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||||
logging.info("Using mixed precision operations")
|
logging.info("Using mixed precision operations")
|
||||||
disabled = set()
|
disabled = set()
|
||||||
if not nvfp4_compute:
|
if not nvfp4_compute:
|
||||||
disabled.add("nvfp4")
|
disabled.add("nvfp4")
|
||||||
|
if not mxfp8_compute:
|
||||||
|
disabled.add("mxfp8")
|
||||||
if not fp8_compute:
|
if not fp8_compute:
|
||||||
disabled.add("float8_e4m3fn")
|
disabled.add("float8_e4m3fn")
|
||||||
disabled.add("float8_e5m2")
|
disabled.add("float8_e5m2")
|
||||||
|
|||||||
@ -43,6 +43,18 @@ except ImportError as e:
|
|||||||
def get_layout_class(name):
|
def get_layout_class(name):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Optional MXFP8 layout support from comfy_kitchen; older versions of the
# package do not ship TensorCoreMXFP8Layout.
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
    try:
        from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
        _CK_MXFP8_AVAILABLE = True
    except ImportError:
        logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")

if not _CK_MXFP8_AVAILABLE:
    # Stub base class so the TensorCoreMXFP8Layout subclass below can still
    # be defined (it is only registered when _CK_MXFP8_AVAILABLE is True).
    class _CKMxfp8Layout:
        pass
|
||||||
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
|||||||
return qdata, params
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
|
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
    # MXFP8 weight layout: float8_e4m3fn data with E8M0 block scales.

    @classmethod
    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
        """Quantize a 2D tensor to MXFP8 and return (qdata, Params).

        Args:
            tensor: 2D weight tensor to quantize.
            scale: unused here; kept for signature parity with other layouts.
            stochastic_rounding: when > 0, use stochastic rounding with this
                value as the seed; otherwise use comfy_kitchen's deterministic
                quantizer.
            inplace_ops: unused here; kept for signature parity.

        Raises:
            ValueError: if `tensor` is not 2D.
        """
        if tensor.dim() != 2:
            raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")

        # Remember the original dtype/shape so Params can restore them later.
        orig_dtype = tensor.dtype
        orig_shape = tuple(tensor.shape)

        padded_shape = cls.get_padded_shape(orig_shape)
        needs_padding = padded_shape != orig_shape

        if stochastic_rounding > 0:
            # stochastic_rounding doubles as the RNG seed.
            qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
        else:
            qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)

        params = cls.Params(
            scale=block_scale,
            orig_dtype=orig_dtype,
            orig_shape=orig_shape,
        )
        return qdata, params
|
||||||
|
|
||||||
|
|
||||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||||
@classmethod
|
@classmethod
|
||||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
|||||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||||
|
|
||||||
QUANT_ALGOS = {
|
QUANT_ALGOS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
@ -157,6 +196,14 @@ QUANT_ALGOS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
QUANT_ALGOS["mxfp8"] = {
|
||||||
|
"storage_t": torch.float8_e4m3fn,
|
||||||
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||||
|
"group_size": 32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Re-exports for backward compatibility
|
# Re-exports for backward compatibility
|
||||||
|
|||||||
@ -32,7 +32,7 @@ async def cache_control(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
||||||
response.headers.setdefault("Cache-Control", "no-cache")
|
response.headers.setdefault("Cache-Control", "no-store")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Early return for non-image files - no cache headers needed
|
# Early return for non-image files - no cache headers needed
|
||||||
|
|||||||
@ -310,7 +310,7 @@ class PromptServer():
|
|||||||
@routes.get("/")
|
@routes.get("/")
|
||||||
async def get_root(request):
|
async def get_root(request):
|
||||||
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||||
response.headers['Cache-Control'] = 'no-cache'
|
response.headers['Cache-Control'] = 'no-store, must-revalidate'
|
||||||
response.headers["Pragma"] = "no-cache"
|
response.headers["Pragma"] = "no-cache"
|
||||||
response.headers["Expires"] = "0"
|
response.headers["Expires"] = "0"
|
||||||
return response
|
return response
|
||||||
|
|||||||
@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
|
|||||||
},
|
},
|
||||||
# JavaScript/CSS scenarios
|
# JavaScript/CSS scenarios
|
||||||
{
|
{
|
||||||
"name": "js_no_cache",
|
"name": "js_no_store",
|
||||||
"path": "/script.js",
|
"path": "/script.js",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "css_no_cache",
|
"name": "css_no_store",
|
||||||
"path": "/styles.css",
|
"path": "/styles.css",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "index_json_no_cache",
|
"name": "index_json_no_store",
|
||||||
"path": "/api/index.json",
|
"path": "/api/index.json",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "localized_index_json_no_cache",
|
"name": "localized_index_json_no_store",
|
||||||
"path": "/templates/index.zh.json",
|
"path": "/templates/index.zh.json",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
# Non-matching files
|
# Non-matching files
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user