From e0982a7174a9cacb0c3cd3fb6bd1f8e06d9aaf51 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 14 Mar 2026 15:25:09 -0700 Subject: [PATCH 1/2] fix: use no-store cache headers to prevent stale frontend chunks (#12911) After a frontend update (e.g. nightly build), browsers could load outdated cached index.html and JS/CSS chunks, causing dynamically imported modules to fail with MIME type errors and vite:preloadError. Hard refresh (Ctrl+Shift+R) was insufficient to fix the issue because Cache-Control: no-cache still allows the browser to cache and revalidate via ETags. aiohttp's FileResponse auto-generates ETags based on file mtime+size, which may not change after pip reinstall, so the browser gets 304 Not Modified and serves stale content. Clearing ALL site data in DevTools did fix it, confirming the HTTP cache was the root cause. The fix changes: - index.html: no-cache -> no-store, must-revalidate - JS/CSS/JSON entry points: no-cache -> no-store no-store instructs browsers to never cache these responses, ensuring every page load fetches the current index.html with correct chunk references. This is a small tradeoff (~5KB re-download per page load) for guaranteed correctness after updates. --- middleware/cache_middleware.py | 2 +- server.py | 2 +- tests-unit/server_test/test_cache_control.py | 16 ++++++++-------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/middleware/cache_middleware.py b/middleware/cache_middleware.py index f02135369..7a18821b0 100644 --- a/middleware/cache_middleware.py +++ b/middleware/cache_middleware.py @@ -32,7 +32,7 @@ async def cache_control( ) if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point: - response.headers.setdefault("Cache-Control", "no-cache") + response.headers.setdefault("Cache-Control", "no-store") return response # Early return for non-image files - no cache headers needed diff --git a/server.py b/server.py index 76904ebc9..85a8964be 100644 --- a/server.py +++ b/server.py @@ -310,7 +310,7 @@ class PromptServer(): @routes.get("/") async def get_root(request): response = web.FileResponse(os.path.join(self.web_root, "index.html")) - response.headers['Cache-Control'] = 'no-cache' + response.headers['Cache-Control'] = 'no-store, must-revalidate' response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" return response diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py index fa68d9408..1d0366387 100644 --- a/tests-unit/server_test/test_cache_control.py +++ b/tests-unit/server_test/test_cache_control.py @@ -28,31 +28,31 @@ CACHE_SCENARIOS = [ }, # JavaScript/CSS scenarios { - "name": "js_no_cache", + "name": "js_no_store", "path": "/script.js", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, { - "name": "css_no_cache", + "name": "css_no_store", "path": "/styles.css", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, { - "name": "index_json_no_cache", + "name": "index_json_no_store", "path": "/api/index.json", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, { - "name": "localized_index_json_no_cache", + "name": "localized_index_json_no_store", "path": "/templates/index.zh.json", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, # Non-matching files From 1c5db7397d59eace38acef078b618c2f04e4e7fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sun, 15 Mar 2026 00:36:29 +0200 Subject: [PATCH 2/2] feat: Support mxfp8 (#12907) --- comfy/float.py | 36 ++++++++++++++++++++++++++++++ comfy/model_management.py | 13 +++++++++++ comfy/ops.py | 19 ++++++++++++++++ comfy/quant_ops.py | 47 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+) diff --git a/comfy/float.py b/comfy/float.py index 88c47cd80..184b3d6d0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed= output_block[i:i + slice_size].copy_(block) return output_fp4, to_blocked(output_block, flatten=False) + + +def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0): + def roundup(x_val, multiple): + return ((x_val + multiple - 1) // multiple) * multiple + + if pad_32x: + rows, cols = x.shape + padded_rows = roundup(rows, 32) + padded_cols = roundup(cols, 32) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + F8_E4M3_MAX = 448.0 + E8M0_BIAS = 127 + BLOCK_SIZE = 32 + + rows, cols = x.shape + x_blocked = x.reshape(rows, -1, BLOCK_SIZE) + max_abs = torch.amax(torch.abs(x_blocked), dim=-1) + + # E8M0 block scales (power-of-2 exponents) + scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127)) + exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254) + block_scales_e8m0 = exp_biased.to(torch.uint8) + + zero_mask = (max_abs == 0) + block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32) + block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32) + + # Scale per-block then stochastic round + data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols) + output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed) + + block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0) + return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu) diff --git a/comfy/model_management.py b/comfy/model_management.py index 4d5851bc0..bb77cff47 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1712,6 +1712,19 @@ def supports_nvfp4_compute(device=None): return True +def supports_mxfp8_compute(device=None): + if not is_nvidia(): + return False + + if torch_version_numeric < (2, 10): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/ops.py b/comfy/ops.py index 3f2da4e63..59c0df87d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -857,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec orig_shape=(self.out_features, self.in_features), ) + elif self.quant_format == "mxfp8": + # MXFP8: E8M0 block scales stored as uint8 in safetensors + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.uint8) + + if block_scale is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + + block_scale = block_scale.view(torch.float8_e8m0fnu) + + params = layout_cls.Params( + scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + elif self.quant_format == "nvfp4": # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) @@ -1006,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) + mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") disabled = set() if not nvfp4_compute: disabled.add("nvfp4") + if not mxfp8_compute: + disabled.add("mxfp8") if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 15a4f457b..42ee08fb2 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -43,6 +43,18 @@ except ImportError as e: def get_layout_class(name): return None +_CK_MXFP8_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout + _CK_MXFP8_AVAILABLE = True + except ImportError: + logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.") + +if not _CK_MXFP8_AVAILABLE: + class _CKMxfp8Layout: + pass + import comfy.float # ============================================================================== @@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreMXFP8Layout(_CKMxfp8Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding) + + params = cls.Params( + scale=block_scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + ) + return qdata, params + + class TensorCoreNVFP4Layout(_CKNvfp4Layout): @classmethod def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): @@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) +if _CK_MXFP8_AVAILABLE: + register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -157,6 +196,14 @@ QUANT_ALGOS = { }, } +if _CK_MXFP8_AVAILABLE: + QUANT_ALGOS["mxfp8"] = { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreMXFP8Layout", + "group_size": 32, + } + # ============================================================================== # Re-exports for backward compatibility