Compare commits


1 Commit

Author   SHA1         Message                            Date
xmarre   9047e20409   Merge e069617e54 into 4c4be1bba5   2026-03-14 20:27:20 +00:00
14 changed files with 129 additions and 723 deletions

View File

@@ -83,8 +83,6 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

View File

@@ -209,39 +209,3 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
    def roundup(x_val, multiple):
        return ((x_val + multiple - 1) // multiple) * multiple
    if pad_32x:
        rows, cols = x.shape
        padded_rows = roundup(rows, 32)
        padded_cols = roundup(cols, 32)
        if padded_rows != rows or padded_cols != cols:
            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
    F8_E4M3_MAX = 448.0
    E8M0_BIAS = 127
    BLOCK_SIZE = 32
    rows, cols = x.shape
    x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
    max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
    # E8M0 block scales (power-of-2 exponents)
    scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
    exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
    block_scales_e8m0 = exp_biased.to(torch.uint8)
    zero_mask = (max_abs == 0)
    block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
    block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
    # Scale per-block then stochastic round
    data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
    output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
    block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
    return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
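The E8M0 scale used above stores only a biased power-of-two exponent per 32-value block; the float scale is recovered as 2**(exponent - 127), which is what the (exp << 23).view(torch.float32) bit pattern reconstructs for nonzero blocks. A minimal worked example of that arithmetic for a single block, in plain Python with an illustrative block maximum (values chosen for illustration only):

import math

F8_E4M3_MAX = 448.0   # largest finite float8_e4m3fn magnitude
E8M0_BIAS = 127

max_abs = 3.0  # hypothetical block maximum
scale_needed = max(max_abs / F8_E4M3_MAX, 2.0 ** -127)
exp_biased = min(max(math.ceil(math.log2(scale_needed)) + E8M0_BIAS, 0), 254)
scale = 2.0 ** (exp_biased - E8M0_BIAS)  # power-of-two block scale

print(exp_biased, scale, max_abs / scale)  # 120 0.0078125 384.0 -- 384 fits within the e4m3 range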

View File

@@ -11,7 +11,6 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
@@ -537,7 +536,7 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample.to(comfy.model_management.intermediate_device()))
output.append(sample)
return
up_block = self.up_blocks[idx]

View File

@@ -1050,12 +1050,6 @@ def intermediate_device():
else:
return torch.device("cpu")
def intermediate_dtype():
if args.fp16_intermediates:
return torch.float16
else:
return torch.float32
def vae_device():
if args.cpu_vae:
return torch.device("cpu")
@@ -1718,19 +1712,6 @@ def supports_nvfp4_compute(device=None):
return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):

View File

@@ -857,22 +857,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@@ -1022,15 +1006,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")

View File

@@ -43,18 +43,6 @@ except ImportError as e:
def get_layout_class(name):
return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float
# ==============================================================================
@@ -96,31 +84,6 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@@ -174,8 +137,6 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
@@ -196,14 +157,6 @@ QUANT_ALGOS = {
},
}
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ==============================================================================
# Re-exports for backward compatibility

View File

@@ -871,16 +871,13 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels
def vae_output_dtype(self):
return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@@ -890,16 +887,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@@ -908,7 +905,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@@ -917,7 +914,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
@@ -926,7 +923,7 @@ class VAE:
tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
@@ -935,7 +932,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}):
@@ -953,9 +950,9 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except Exception as e:
model_management.raise_non_oom(e)
@@ -1028,9 +1025,9 @@ class VAE:
samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except Exception as e:

View File

@@ -60,311 +60,135 @@ class Unhashable:
pass
_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None))
_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset)
_MAX_SIGNATURE_DEPTH = 32
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
_FAILED_SIGNATURE = object()
def _primitive_signature_sort_key(obj):
"""Return a deterministic ordering key for primitive signature values."""
obj_type = type(obj)
return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj))
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
def _sanitized_sort_key(obj, depth=0, max_depth=32):
"""Return a deterministic ordering key for sanitized built-in container content."""
if depth >= max_depth:
return ("MAX_DEPTH",)
if active is None:
active = set()
if memo is None:
memo = {}
obj_type = type(obj)
if obj_type is Unhashable:
return ("UNHASHABLE",)
elif obj_type in _PRIMITIVE_SIGNATURE_TYPES:
elif obj_type in (int, float, str, bool, bytes, type(None)):
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
elif obj_type not in _CONTAINER_SIGNATURE_TYPES:
elif obj_type is dict:
items = [
(
_sanitized_sort_key(k, depth + 1, max_depth),
_sanitized_sort_key(v, depth + 1, max_depth),
)
for k, v in obj.items()
]
items.sort()
return ("dict", tuple(items))
elif obj_type is list:
return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
elif obj_type is tuple:
return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
elif obj_type is set:
return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
elif obj_type is frozenset:
return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
else:
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if obj_id in active:
return ("CYCLE",)
active.add(obj_id)
try:
if obj_type is dict:
items = [
(
_sanitized_sort_key(k, depth + 1, max_depth, active, memo),
_sanitized_sort_key(v, depth + 1, max_depth, active, memo),
)
for k, v in obj.items()
]
items.sort()
result = ("dict", tuple(items))
elif obj_type is list:
result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
elif obj_type is tuple:
result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
elif obj_type is set:
result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
else:
result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
finally:
active.discard(obj_id)
def _sanitize_signature_input(obj, depth=0, max_depth=32, seen=None):
"""Normalize signature inputs to safe built-in containers.
memo[obj_id] = result
return result
def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
"""Canonicalize signature inputs directly into their final hashable form."""
Preserves built-in container type, replaces opaque runtime values with
Unhashable(), and stops safely on cycles or excessive depth.
"""
if depth >= max_depth:
return _FAILED_SIGNATURE
return Unhashable()
if active is None:
active = set()
if memo is None:
memo = {}
if budget is None:
budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}
if seen is None:
seen = set()
obj_type = type(obj)
if obj_type in _PRIMITIVE_SIGNATURE_TYPES:
return obj, _primitive_signature_sort_key(obj)
if obj_type is Unhashable or obj_type not in _CONTAINER_SIGNATURE_TYPES:
return _FAILED_SIGNATURE
if obj_type in (dict, list, tuple, set, frozenset):
obj_id = id(obj)
if obj_id in seen:
return Unhashable()
next_seen = seen | {obj_id}
obj_id = id(obj)
if obj_id in memo:
return memo[obj_id]
if obj_id in active:
return _FAILED_SIGNATURE
budget["remaining"] -= 1
if budget["remaining"] < 0:
return _FAILED_SIGNATURE
active.add(obj_id)
try:
if obj_type is dict:
try:
items = list(obj.items())
except RuntimeError:
return _FAILED_SIGNATURE
ordered_items = []
for key, value in items:
key_result = _signature_to_hashable_impl(key, depth + 1, max_depth, active, memo, budget)
if key_result is _FAILED_SIGNATURE:
return _FAILED_SIGNATURE
value_result = _signature_to_hashable_impl(value, depth + 1, max_depth, active, memo, budget)
if value_result is _FAILED_SIGNATURE:
return _FAILED_SIGNATURE
key_value, key_sort = key_result
value_value, value_sort = value_result
ordered_items.append((((key_sort, value_sort)), (key_value, value_value)))
ordered_items.sort(key=lambda item: item[0])
for index in range(1, len(ordered_items)):
previous_sort_key, previous_item = ordered_items[index - 1]
current_sort_key, current_item = ordered_items[index]
if previous_sort_key == current_sort_key and previous_item != current_item:
return _FAILED_SIGNATURE
value = ("dict", tuple(item for _, item in ordered_items))
sort_key = ("dict", tuple(sort_key for sort_key, _ in ordered_items))
elif obj_type is list or obj_type is tuple:
try:
items = list(obj)
except RuntimeError:
return _FAILED_SIGNATURE
child_results = []
for item in items:
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
if child_result is _FAILED_SIGNATURE:
return _FAILED_SIGNATURE
child_results.append(child_result)
container_tag = "list" if obj_type is list else "tuple"
value = (container_tag, tuple(child for child, _ in child_results))
sort_key = (container_tag, tuple(child_sort for _, child_sort in child_results))
else:
try:
items = list(obj)
except RuntimeError:
return _FAILED_SIGNATURE
ordered_items = []
for item in items:
child_result = _signature_to_hashable_impl(item, depth + 1, max_depth, active, memo, budget)
if child_result is _FAILED_SIGNATURE:
return _FAILED_SIGNATURE
child_value, child_sort = child_result
ordered_items.append((child_sort, child_value))
ordered_items.sort(key=lambda item: item[0])
for index in range(1, len(ordered_items)):
previous_sort_key, previous_value = ordered_items[index - 1]
current_sort_key, current_value = ordered_items[index]
if previous_sort_key == current_sort_key and previous_value != current_value:
return _FAILED_SIGNATURE
container_tag = "set" if obj_type is set else "frozenset"
value = (container_tag, tuple(child_value for _, child_value in ordered_items))
sort_key = (container_tag, tuple(child_sort for child_sort, _ in ordered_items))
finally:
active.discard(obj_id)
memo[obj_id] = (value, sort_key)
return memo[obj_id]
def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
"""Build the final cache-signature representation in one fail-closed pass."""
result = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
if result is _FAILED_SIGNATURE:
if obj_type in (int, float, str, bool, bytes, type(None)):
return obj
elif obj_type is dict:
sanitized_items = [
(
_sanitize_signature_input(key, depth + 1, max_depth, next_seen),
_sanitize_signature_input(value, depth + 1, max_depth, next_seen),
)
for key, value in obj.items()
]
sanitized_items.sort(
key=lambda kv: (
_sanitized_sort_key(kv[0], depth + 1, max_depth),
_sanitized_sort_key(kv[1], depth + 1, max_depth),
)
)
return {key: value for key, value in sanitized_items}
elif obj_type is list:
return [_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj]
elif obj_type is tuple:
return tuple(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
elif obj_type is set:
return {_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj}
elif obj_type is frozenset:
return frozenset(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
else:
# Execution-cache signatures should be built from prompt-safe values.
# If a custom node injects a runtime object here, mark it unhashable so
# the node won't reuse stale cache entries across runs, but do not walk
# the foreign object and risk crashing on custom container semantics.
return Unhashable()
return result[0]
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
def to_hashable(obj, depth=0, max_depth=32, seen=None):
"""Convert sanitized prompt inputs into a stable hashable representation.
The input is expected to already be sanitized to plain built-in containers,
but this function still fails safe for anything unexpected. Traversal is
iterative and memoized so shared built-in substructures do not trigger
exponential re-walks during cache-key construction.
Preserves built-in container type and stops safely on cycles or excessive depth.
"""
obj_type = type(obj)
if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable:
return obj
if obj_type not in _CONTAINER_SIGNATURE_TYPES:
if depth >= max_depth:
return Unhashable()
memo = {}
active = set()
snapshots = {}
sort_memo = {}
processed = 0
stack = [(obj, False)]
if seen is None:
seen = set()
def resolve_value(value):
"""Resolve a child value from the completed memo table when available."""
value_type = type(value)
if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable:
return value
return memo.get(id(value), Unhashable())
# Restrict recursion to plain built-in containers. Some custom nodes insert
# runtime objects into prompt inputs for dynamic graph paths; walking those
# objects as generic Mappings / Sequences is unsafe and can destabilize the
# cache signature builder.
obj_type = type(obj)
if obj_type in (int, float, str, bool, bytes, type(None)):
return obj
def resolve_unordered_values(current_items, container_tag):
"""Resolve a set-like container or fail closed if ordering is ambiguous."""
try:
ordered_items = [
(_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
for item in current_items
]
ordered_items.sort(key=lambda item: item[0])
except RuntimeError:
if obj_type in (dict, list, tuple, set, frozenset):
obj_id = id(obj)
if obj_id in seen:
return Unhashable()
seen = seen | {obj_id}
for index in range(1, len(ordered_items)):
previous_key, previous_value = ordered_items[index - 1]
current_key, current_value = ordered_items[index]
if previous_key == current_key and previous_value != current_value:
return Unhashable()
return (container_tag, tuple(value for _, value in ordered_items))
while stack:
current, expanded = stack.pop()
current_type = type(current)
if current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable:
continue
if current_type not in _CONTAINER_SIGNATURE_TYPES:
memo[id(current)] = Unhashable()
continue
current_id = id(current)
if current_id in memo:
continue
if expanded:
active.discard(current_id)
try:
if current_type is dict:
items = snapshots.pop(current_id, None)
if items is None:
items = list(current.items())
memo[current_id] = (
"dict",
tuple((resolve_value(k), resolve_value(v)) for k, v in items),
)
elif current_type is list:
items = snapshots.pop(current_id, None)
if items is None:
items = list(current)
memo[current_id] = ("list", tuple(resolve_value(item) for item in items))
elif current_type is tuple:
items = snapshots.pop(current_id, None)
if items is None:
items = list(current)
memo[current_id] = ("tuple", tuple(resolve_value(item) for item in items))
elif current_type is set:
items = snapshots.pop(current_id, None)
if items is None:
items = list(current)
memo[current_id] = resolve_unordered_values(items, "set")
else:
items = snapshots.pop(current_id, None)
if items is None:
items = list(current)
memo[current_id] = resolve_unordered_values(items, "frozenset")
except RuntimeError:
memo[current_id] = Unhashable()
continue
if current_id in active:
memo[current_id] = Unhashable()
continue
processed += 1
if processed > max_nodes:
return Unhashable()
active.add(current_id)
stack.append((current, True))
if current_type is dict:
try:
items = list(current.items())
snapshots[current_id] = items
except RuntimeError:
memo[current_id] = Unhashable()
active.discard(current_id)
continue
for key, value in reversed(items):
stack.append((value, False))
stack.append((key, False))
else:
try:
items = list(current)
snapshots[current_id] = items
except RuntimeError:
memo[current_id] = Unhashable()
active.discard(current_id)
continue
for item in reversed(items):
stack.append((item, False))
return memo.get(id(obj), Unhashable())
if obj_type is dict:
return (
"dict",
frozenset(
(
to_hashable(k, depth + 1, max_depth, seen),
to_hashable(v, depth + 1, max_depth, seen),
)
for k, v in obj.items()
),
)
elif obj_type is list:
return ("list", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
elif obj_type is tuple:
return ("tuple", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
elif obj_type is set:
return ("set", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
elif obj_type is frozenset:
return ("frozenset", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
else:
return Unhashable()
class CacheKeySetID(CacheKeySet):
"""Cache-key strategy that keys nodes by node id and class type."""
@@ -414,7 +238,7 @@ class CacheKeySetInputSignature(CacheKeySet):
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
for ancestor_id in ancestors:
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
return _signature_to_hashable(signature)
return to_hashable(signature)
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
"""Build the cache-signature fragment for a node's immediate inputs.
@@ -437,7 +261,7 @@ class CacheKeySetInputSignature(CacheKeySet):
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
signature.append((key, _sanitize_signature_input(inputs[key])))
return signature
# This function returns a list of all ancestors of the given node. The order of the list is
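As a quick sanity check on the recursive to_hashable variant shown in this hunk, a minimal usage sketch (the expected values follow directly from that implementation; the import path is the comfy_execution.caching module referenced elsewhere in this diff):

from comfy_execution.caching import Unhashable, to_hashable

key = to_hashable({"a": [1, 2], "b": {3}})
# ("dict", frozenset({("a", ("list", (1, 2))), ("b", ("set", frozenset({3})))}))
assert key == to_hashable({"b": {3}, "a": [1, 2]})    # dict insertion order does not matter
assert isinstance(to_hashable(object()), Unhashable)  # opaque runtime values fail closed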

View File

@@ -32,7 +32,7 @@ async def cache_control(
)
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
response.headers.setdefault("Cache-Control", "no-store")
response.headers.setdefault("Cache-Control", "no-cache")
return response
# Early return for non-image files - no cache headers needed
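For context on the header change above: "no-store" tells clients not to keep a copy at all, while "no-cache" allows a stored copy but requires revalidation with the server before each reuse. A minimal, self-contained sketch of the same middleware pattern in aiohttp (illustrative only, not the project's actual module):

from aiohttp import web

@web.middleware
async def cache_control(request, handler):
    response = await handler(request)
    if request.path.endswith((".js", ".css")):
        # Stored copies must be revalidated (e.g. via ETag/Last-Modified)
        # before reuse; "no-store" would forbid keeping them at all.
        response.headers.setdefault("Cache-Control", "no-cache")
    return response

app = web.Application(middlewares=[cache_control])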

View File

@@ -1724,8 +1724,6 @@ class LoadImage:
output_masks = []
w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
@@ -1750,8 +1748,8 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.41.20
comfyui-frontend-package==1.41.19
comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3
torch

View File

@@ -310,7 +310,7 @@ class PromptServer():
@routes.get("/")
async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers['Cache-Control'] = 'no-cache'
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response

View File

@@ -1,289 +0,0 @@
"""Unit tests for cache-signature canonicalization hardening."""
import asyncio
import importlib
import sys
import types
import pytest
class _DummyNode:
"""Minimal node stub used to satisfy cache-signature class lookups."""
@staticmethod
def INPUT_TYPES():
"""Return a minimal empty input schema for unit tests."""
return {"required": {}}
class _FakeDynPrompt:
"""Small DynamicPrompt stand-in with only the methods these tests need."""
def __init__(self, nodes_by_id):
"""Store test nodes by id."""
self._nodes_by_id = nodes_by_id
def has_node(self, node_id):
"""Return whether the fake prompt contains the requested node."""
return node_id in self._nodes_by_id
def get_node(self, node_id):
"""Return the stored node payload for the requested id."""
return self._nodes_by_id[node_id]
class _FakeIsChangedCache:
"""Async stub for `is_changed` lookups used by cache-key generation."""
def __init__(self, values):
"""Store canned `is_changed` responses keyed by node id."""
self._values = values
async def get(self, node_id):
"""Return the canned `is_changed` value for a node."""
return self._values[node_id]
class _OpaqueValue:
"""Hashable opaque object used to exercise fail-closed unordered hashing paths."""
@pytest.fixture
def caching_module(monkeypatch):
"""Import `comfy_execution.caching` with lightweight stub dependencies."""
torch_module = types.ModuleType("torch")
psutil_module = types.ModuleType("psutil")
nodes_module = types.ModuleType("nodes")
nodes_module.NODE_CLASS_MAPPINGS = {}
graph_module = types.ModuleType("comfy_execution.graph")
class DynamicPrompt:
"""Placeholder graph type so the caching module can import cleanly."""
pass
graph_module.DynamicPrompt = DynamicPrompt
monkeypatch.setitem(sys.modules, "torch", torch_module)
monkeypatch.setitem(sys.modules, "psutil", psutil_module)
monkeypatch.setitem(sys.modules, "nodes", nodes_module)
monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)
module = importlib.import_module("comfy_execution.caching")
module = importlib.reload(module)
return module, nodes_module
def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module):
"""Shared built-in substructures should canonicalize without collapsing to Unhashable."""
caching, _ = caching_module
shared = [{"value": 1}, {"value": 2}]
signature = caching._signature_to_hashable([shared, shared])
assert signature[0] == "list"
assert signature[1][0] == signature[1][1]
assert signature[1][0][0] == "list"
assert signature[1][0][1][0] == ("dict", (("value", 1),))
assert signature[1][0][1][1] == ("dict", (("value", 2),))
def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module):
"""Opaque values should collapse the full signature to Unhashable immediately."""
caching, _ = caching_module
signature = caching._signature_to_hashable(["safe", object()])
assert isinstance(signature, caching.Unhashable)
def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch):
"""Once canonicalization fails, later recursive descent should stop immediately."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = object()
marker_seen = False
def tracking_canonicalize(obj, *args, **kwargs):
"""Track whether recursion reaches the nested marker after failure."""
nonlocal marker_seen
if obj is marker:
marker_seen = True
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize)
signature = caching._signature_to_hashable([object(), [marker]])
assert isinstance(signature, caching.Unhashable)
assert marker_seen is False
def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch):
"""List canonicalization should read a point-in-time snapshot before recursive descent."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = ("marker",)
values = [marker, 2]
def mutating_canonicalize(obj, *args, **kwargs):
"""Mutate the live list during recursion to verify snapshot-based traversal."""
if obj is marker:
values[1] = 3
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
signature = caching._signature_to_hashable(values)
assert signature == ("list", (("tuple", ("marker",)), 2))
assert values[1] == 3
def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch):
"""Dict canonicalization should read a point-in-time snapshot before recursive descent."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = ("marker",)
values = {"first": marker, "second": 2}
def mutating_canonicalize(obj, *args, **kwargs):
"""Mutate the live dict during recursion to verify snapshot-based traversal."""
if obj is marker:
values["second"] = 3
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)
signature = caching._signature_to_hashable(values)
assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2)))
assert values["second"] == 3
@pytest.mark.parametrize(
"container_factory",
[
lambda marker: [marker],
lambda marker: (marker,),
lambda marker: {marker},
lambda marker: frozenset({marker}),
lambda marker: {marker: "value"},
],
)
def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
"""Traversal RuntimeError should degrade canonicalization to Unhashable."""
caching, _ = caching_module
original = caching._signature_to_hashable_impl
marker = object()
def raising_canonicalize(obj, *args, **kwargs):
"""Raise a traversal RuntimeError for the marker value and delegate otherwise."""
if obj is marker:
raise RuntimeError("container changed during iteration")
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize)
signature = caching._signature_to_hashable(container_factory(marker))
assert isinstance(signature, caching.Unhashable)
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
"""The legacy helper should still hash sanitized built-ins stably when used directly."""
caching, _ = caching_module
shared = [{"value": 1}, {"value": 2}]
sanitized = [shared, shared]
hashable = caching.to_hashable(sanitized)
assert hashable[0] == "list"
assert hashable[1][0] == hashable[1][1]
assert hashable[1][0][0] == "list"
@pytest.mark.parametrize(
"container_factory",
[
set,
frozenset,
],
)
def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
"""Traversal RuntimeError should degrade unordered hash conversion to Unhashable."""
caching, _ = caching_module
def raising_sort_key(obj, *args, **kwargs):
"""Raise a traversal RuntimeError while unordered values are canonicalized."""
raise RuntimeError("container changed during iteration")
monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key)
hashable = caching.to_hashable(container_factory({"value"}))
assert isinstance(hashable, caching.Unhashable)
def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module):
"""Ambiguous dict sort ties should fail closed instead of depending on input order."""
caching, _ = caching_module
ambiguous = {
_OpaqueValue(): _OpaqueValue(),
_OpaqueValue(): _OpaqueValue(),
}
sanitized = caching._signature_to_hashable(ambiguous)
assert isinstance(sanitized, caching.Unhashable)
@pytest.mark.parametrize(
"container_factory",
[
set,
frozenset,
],
)
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory):
"""Ambiguous unordered values should fail closed instead of depending on iteration order."""
caching, _ = caching_module
container = container_factory({_OpaqueValue(), _OpaqueValue()})
hashable = caching.to_hashable(container)
assert isinstance(hashable, caching.Unhashable)
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
"""Tainted full signatures should fail closed before `to_hashable()` runs."""
caching, nodes_module = caching_module
monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
monkeypatch.setattr(
caching,
"to_hashable",
lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
)
is_changed_value = []
is_changed_value.append(is_changed_value)
dynprompt = _FakeDynPrompt(
{
"node": {
"class_type": "UnitTestNode",
"inputs": {"value": 5},
}
}
)
key_set = caching.CacheKeySetInputSignature(
dynprompt,
["node"],
_FakeIsChangedCache({"node": is_changed_value}),
)
signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))
assert isinstance(signature, caching.Unhashable)

View File

@@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
},
# JavaScript/CSS scenarios
{
"name": "js_no_store",
"name": "js_no_cache",
"path": "/script.js",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "css_no_store",
"name": "css_no_cache",
"path": "/styles.css",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "index_json_no_store",
"name": "index_json_no_cache",
"path": "/api/index.json",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "localized_index_json_no_store",
"name": "localized_index_json_no_cache",
"path": "/templates/index.zh.json",
"status": 200,
"expected_cache": "no-store",
"expected_cache": "no-cache",
"should_have_header": True,
},
# Non-matching files