mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-23 10:03:36 +08:00
Compare commits
17 Commits
9047e20409
...
a2f739ea5d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2f739ea5d | ||
|
|
a6624a9afd | ||
|
|
0b512198e8 | ||
|
|
9feb26928c | ||
|
|
fadd79ad48 | ||
|
|
77bc7bdd6b | ||
|
|
117afbc1d7 | ||
|
|
064eec2278 | ||
|
|
aceaa5e579 | ||
|
|
0904cc3fe5 | ||
|
|
763089f681 | ||
|
|
4941cd046e | ||
|
|
1693dabc8f | ||
|
|
c711b8f437 | ||
|
|
08063d2638 | ||
|
|
1c5db7397d | ||
|
|
e0982a7174 |
@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||
|
||||
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
|
||||
|
||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
    """Quantize a 2D tensor to MXFP8 using stochastic rounding.

    Every run of 32 consecutive elements along the last dimension shares one
    power-of-two (E8M0) scale; the scaled data is rounded to
    ``torch.float8_e4m3fn`` by ``stochastic_rounding``.

    Args:
        x: 2D tensor. Its column count must be a multiple of 32 unless
            ``pad_32x`` is set.
        pad_32x: When true, zero-pad rows and columns up to the next multiple
            of 32 before quantizing.
        seed: Seed forwarded to ``stochastic_rounding``.

    Returns:
        Tuple ``(output_fp8, block_scales)`` where ``output_fp8`` is the result
        of ``stochastic_rounding`` on the scaled (possibly padded) data and
        ``block_scales`` is the uint8 exponent tensor passed through
        ``to_blocked(..., flatten=False)`` and viewed as
        ``torch.float8_e8m0fnu``.
    """
    def roundup(x_val, multiple):
        return ((x_val + multiple - 1) // multiple) * multiple

    F8_E4M3_MAX = 448.0
    E8M0_BIAS = 127
    BLOCK_SIZE = 32
    # A biased exponent of 0 decodes to 0.0 when the bits are reinterpreted as
    # float32 below, which would divide non-zero data by zero. Keep non-zero
    # blocks at the smallest normal exponent instead.
    E8M0_MIN_USABLE_EXP = 1

    if pad_32x:
        rows, cols = x.shape
        padded_rows = roundup(rows, 32)
        padded_cols = roundup(cols, 32)
        if padded_rows != rows or padded_cols != cols:
            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))

    rows, cols = x.shape
    x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
    max_abs = torch.amax(torch.abs(x_blocked), dim=-1)

    # E8M0 block scales (power-of-2 exponents). The lower clamp on
    # scale_needed keeps log2 finite for all-zero blocks.
    scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
    exp_biased = torch.clamp(
        torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS,
        E8M0_MIN_USABLE_EXP,
        254,
    )
    block_scales_e8m0 = exp_biased.to(torch.uint8)

    zero_mask = (max_abs == 0)
    # Placing the biased exponent in the float32 exponent field yields 2**(exp - 127).
    block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
    block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)

    # Scale per-block then stochastic round
    data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
    output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)

    # All-zero blocks are stored with exponent byte 0.
    block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
    return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||
|
||||
@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@ -536,7 +537,7 @@ class Decoder(nn.Module):
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample)
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
|
||||
@ -1050,6 +1050,12 @@ def intermediate_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def intermediate_dtype():
    """Return the dtype for intermediate tensors.

    ``torch.float16`` when ``args.fp16_intermediates`` is truthy, otherwise
    ``torch.float32``.
    """
    return torch.float16 if args.fp16_intermediates else torch.float32
|
||||
|
||||
def vae_device():
|
||||
if args.cpu_vae:
|
||||
return torch.device("cpu")
|
||||
@ -1712,6 +1718,19 @@ def supports_nvfp4_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_mxfp8_compute(device=None):
    """Report whether MXFP8 compute is considered available on ``device``.

    All of the following must hold: ``is_nvidia()`` is true,
    ``torch_version_numeric`` is at least ``(2, 10)``, and the CUDA device
    properties report a major version of 10 or higher.

    Args:
        device: Passed unchanged to ``torch.cuda.get_device_properties``.

    Returns:
        ``True`` when every requirement is met, ``False`` otherwise.
    """
    # Cheap software checks first; the device query only runs when they pass.
    if not is_nvidia() or torch_version_numeric < (2, 10):
        return False

    return torch.cuda.get_device_properties(device).major >= 10
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
|
||||
19
comfy/ops.py
19
comfy/ops.py
@ -857,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "mxfp8":
|
||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.uint8)
|
||||
|
||||
if block_scale is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
|
||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
@ -1006,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not mxfp8_compute:
|
||||
disabled.add("mxfp8")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
|
||||
@ -43,6 +43,18 @@ except ImportError as e:
|
||||
def get_layout_class(name):
|
||||
return None
|
||||
|
||||
# MXFP8 support is probed separately: the layout import may fail even when
# comfy_kitchen itself is available.
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
    try:
        from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
    except ImportError:
        logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
    else:
        _CK_MXFP8_AVAILABLE = True

if not _CK_MXFP8_AVAILABLE:
    # Empty stand-in so code that subclasses _CKMxfp8Layout can still be defined.
    class _CKMxfp8Layout:
        pass
|
||||
|
||||
import comfy.float
|
||||
|
||||
# ==============================================================================
|
||||
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
    """MXFP8 layout whose ``quantize`` can optionally use stochastic rounding."""

    @classmethod
    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
        """Quantize a 2D tensor to MXFP8.

        Args:
            tensor: 2D tensor to quantize.
            scale: Not used by this method.
            stochastic_rounding: When greater than 0, it is used as the seed for
                ``comfy.float.stochastic_round_quantize_mxfp8_by_block``;
                otherwise ``ck.quantize_mxfp8`` performs the quantization.
            inplace_ops: Not used by this method.

        Returns:
            Tuple ``(qdata, params)`` where ``params`` is a ``cls.Params``
            carrying the block scale plus the input's dtype and shape.

        Raises:
            ValueError: If ``tensor`` is not 2-dimensional.
        """
        ndim = tensor.dim()
        if ndim != 2:
            raise ValueError(f"MXFP8 requires 2D tensor, got {ndim}D")

        source_dtype = tensor.dtype
        source_shape = tuple(tensor.shape)

        # Padding is only requested when the layout's padded shape differs.
        pad_required = cls.get_padded_shape(source_shape) != source_shape

        if stochastic_rounding > 0:
            qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(
                tensor, pad_32x=pad_required, seed=stochastic_rounding
            )
        else:
            qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=pad_required)

        return qdata, cls.Params(
            scale=block_scale,
            orig_dtype=source_dtype,
            orig_shape=source_shape,
        )
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
@ -157,6 +196,14 @@ QUANT_ALGOS = {
|
||||
},
|
||||
}
|
||||
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
QUANT_ALGOS["mxfp8"] = {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
|
||||
27
comfy/sd.py
27
comfy/sd.py
@ -871,13 +871,16 @@ class VAE:
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
return pixels
|
||||
|
||||
def vae_output_dtype(self):
    """Return the dtype reported by ``model_management.intermediate_dtype()``."""
    return model_management.intermediate_dtype()
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
@ -887,16 +890,16 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
    """Decode ``samples`` in tiles along a single axis.

    Inputs that are not 3-dimensional are reshaped to
    ``(shape[0], shape[1] * shape[2], -1)`` for tiling, and each tile is
    reshaped back to ``(-1, shape[1], shape[2], tile_len)`` before decoding.

    Args:
        samples: Tensor to decode.
        tile_x: Tile length passed to ``tiled_scale_multidim``.
        overlap: Tile overlap passed to ``tiled_scale_multidim``.

    Returns:
        ``self.process_output`` applied to the tiled decode result.
    """
    if samples.ndim == 3:
        decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
    else:
        og_shape = samples.shape
        samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
        decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())

    return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
    """Decode ``samples`` in tiles over three axes.

    Each tile is moved to ``self.vae_dtype`` / ``self.device``, decoded, and
    cast to ``self.vae_output_dtype()``.

    Args:
        samples: Tensor to decode.
        tile_t, tile_x, tile_y: Tile sizes passed to ``tiled_scale_multidim``.
        overlap: Per-axis overlap passed to ``tiled_scale_multidim``.

    Returns:
        ``self.process_output`` applied to the tiled decode result.
    """
    decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
    return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@ -905,7 +908,7 @@ class VAE:
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@ -914,7 +917,7 @@ class VAE:
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
@ -923,7 +926,7 @@ class VAE:
|
||||
tile_x = tile_x // extra_channel_size
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
@ -932,7 +935,7 @@ class VAE:
|
||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
    """Encode ``samples`` in tiles over three axes.

    Each tile goes through ``self.process_input``, is moved to
    ``self.vae_dtype`` / ``self.device``, encoded, and cast to
    ``self.vae_output_dtype()``.

    Args:
        samples: Tensor to encode.
        tile_t, tile_x, tile_y: Tile sizes passed to ``tiled_scale_multidim``.
        overlap: Per-axis overlap passed to ``tiled_scale_multidim``.

    Returns:
        The tensor produced by ``comfy.utils.tiled_scale_multidim``.
    """
    encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
    return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
@ -950,9 +953,9 @@ class VAE:
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
@ -1025,9 +1028,9 @@ class VAE:
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -60,135 +60,311 @@ class Unhashable:
|
||||
pass
|
||||
|
||||
|
||||
def _sanitized_sort_key(obj, depth=0, max_depth=32):
|
||||
_PRIMITIVE_SIGNATURE_TYPES = (int, float, str, bool, bytes, type(None))
|
||||
_CONTAINER_SIGNATURE_TYPES = (dict, list, tuple, set, frozenset)
|
||||
_MAX_SIGNATURE_DEPTH = 32
|
||||
_MAX_SIGNATURE_CONTAINER_VISITS = 10_000
|
||||
_FAILED_SIGNATURE = object()
|
||||
|
||||
|
||||
def _primitive_signature_sort_key(obj):
|
||||
"""Return a deterministic ordering key for primitive signature values."""
|
||||
obj_type = type(obj)
|
||||
return ("primitive", obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||
|
||||
|
||||
def _sanitized_sort_key(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None):
|
||||
"""Return a deterministic ordering key for sanitized built-in container content."""
|
||||
if depth >= max_depth:
|
||||
return ("MAX_DEPTH",)
|
||||
|
||||
if active is None:
|
||||
active = set()
|
||||
if memo is None:
|
||||
memo = {}
|
||||
|
||||
obj_type = type(obj)
|
||||
if obj_type is Unhashable:
|
||||
return ("UNHASHABLE",)
|
||||
elif obj_type in (int, float, str, bool, bytes, type(None)):
|
||||
elif obj_type in _PRIMITIVE_SIGNATURE_TYPES:
|
||||
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||
elif obj_type is dict:
|
||||
items = [
|
||||
(
|
||||
_sanitized_sort_key(k, depth + 1, max_depth),
|
||||
_sanitized_sort_key(v, depth + 1, max_depth),
|
||||
)
|
||||
for k, v in obj.items()
|
||||
]
|
||||
items.sort()
|
||||
return ("dict", tuple(items))
|
||||
elif obj_type is list:
|
||||
return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
|
||||
elif obj_type is tuple:
|
||||
return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
|
||||
elif obj_type is set:
|
||||
return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
|
||||
elif obj_type is frozenset:
|
||||
return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
|
||||
else:
|
||||
elif obj_type not in _CONTAINER_SIGNATURE_TYPES:
|
||||
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
if obj_id in active:
|
||||
return ("CYCLE",)
|
||||
|
||||
def _sanitize_signature_input(obj, depth=0, max_depth=32, seen=None):
|
||||
"""Normalize signature inputs to safe built-in containers.
|
||||
|
||||
Preserves built-in container type, replaces opaque runtime values with
|
||||
Unhashable(), and stops safely on cycles or excessive depth.
|
||||
"""
|
||||
if depth >= max_depth:
|
||||
return Unhashable()
|
||||
|
||||
if seen is None:
|
||||
seen = set()
|
||||
|
||||
obj_type = type(obj)
|
||||
if obj_type in (dict, list, tuple, set, frozenset):
|
||||
obj_id = id(obj)
|
||||
if obj_id in seen:
|
||||
return Unhashable()
|
||||
next_seen = seen | {obj_id}
|
||||
|
||||
if obj_type in (int, float, str, bool, bytes, type(None)):
|
||||
return obj
|
||||
elif obj_type is dict:
|
||||
sanitized_items = [
|
||||
(
|
||||
_sanitize_signature_input(key, depth + 1, max_depth, next_seen),
|
||||
_sanitize_signature_input(value, depth + 1, max_depth, next_seen),
|
||||
)
|
||||
for key, value in obj.items()
|
||||
]
|
||||
sanitized_items.sort(
|
||||
key=lambda kv: (
|
||||
_sanitized_sort_key(kv[0], depth + 1, max_depth),
|
||||
_sanitized_sort_key(kv[1], depth + 1, max_depth),
|
||||
)
|
||||
)
|
||||
return {key: value for key, value in sanitized_items}
|
||||
elif obj_type is list:
|
||||
return [_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj]
|
||||
elif obj_type is tuple:
|
||||
return tuple(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
|
||||
elif obj_type is set:
|
||||
return {_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj}
|
||||
elif obj_type is frozenset:
|
||||
return frozenset(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
|
||||
else:
|
||||
# Execution-cache signatures should be built from prompt-safe values.
|
||||
# If a custom node injects a runtime object here, mark it unhashable so
|
||||
# the node won't reuse stale cache entries across runs, but do not walk
|
||||
# the foreign object and risk crashing on custom container semantics.
|
||||
return Unhashable()
|
||||
|
||||
def to_hashable(obj, depth=0, max_depth=32, seen=None):
|
||||
"""Convert sanitized prompt inputs into a stable hashable representation.
|
||||
|
||||
Preserves built-in container type and stops safely on cycles or excessive depth.
|
||||
"""
|
||||
if depth >= max_depth:
|
||||
return Unhashable()
|
||||
|
||||
if seen is None:
|
||||
seen = set()
|
||||
|
||||
# Restrict recursion to plain built-in containers. Some custom nodes insert
|
||||
# runtime objects into prompt inputs for dynamic graph paths; walking those
|
||||
# objects as generic Mappings / Sequences is unsafe and can destabilize the
|
||||
# cache signature builder.
|
||||
obj_type = type(obj)
|
||||
if obj_type in (int, float, str, bool, bytes, type(None)):
|
||||
return obj
|
||||
|
||||
if obj_type in (dict, list, tuple, set, frozenset):
|
||||
obj_id = id(obj)
|
||||
if obj_id in seen:
|
||||
return Unhashable()
|
||||
seen = seen | {obj_id}
|
||||
|
||||
if obj_type is dict:
|
||||
return (
|
||||
"dict",
|
||||
frozenset(
|
||||
active.add(obj_id)
|
||||
try:
|
||||
if obj_type is dict:
|
||||
items = [
|
||||
(
|
||||
to_hashable(k, depth + 1, max_depth, seen),
|
||||
to_hashable(v, depth + 1, max_depth, seen),
|
||||
_sanitized_sort_key(k, depth + 1, max_depth, active, memo),
|
||||
_sanitized_sort_key(v, depth + 1, max_depth, active, memo),
|
||||
)
|
||||
for k, v in obj.items()
|
||||
),
|
||||
)
|
||||
elif obj_type is list:
|
||||
return ("list", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||
elif obj_type is tuple:
|
||||
return ("tuple", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||
elif obj_type is set:
|
||||
return ("set", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||
elif obj_type is frozenset:
|
||||
return ("frozenset", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||
else:
|
||||
]
|
||||
items.sort()
|
||||
result = ("dict", tuple(items))
|
||||
elif obj_type is list:
|
||||
result = ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||
elif obj_type is tuple:
|
||||
result = ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj))
|
||||
elif obj_type is set:
|
||||
result = ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||
else:
|
||||
result = ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth, active, memo) for i in obj)))
|
||||
finally:
|
||||
active.discard(obj_id)
|
||||
|
||||
memo[obj_id] = result
|
||||
return result
|
||||
|
||||
|
||||
def _signature_to_hashable_impl(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, active=None, memo=None, budget=None):
    """Canonicalize a signature input into ``(hashable_value, sort_key)``.

    Primitives are returned as-is together with their primitive sort key.
    Built-in containers become ``(tag, tuple_of_children)``; dict, set and
    frozenset children are ordered by their sort keys so the result does not
    depend on iteration order.

    Returns ``_FAILED_SIGNATURE`` instead of a pair when:
      * ``depth`` reaches ``max_depth``;
      * the value is ``Unhashable`` or any type outside the primitive and
        container type tables;
      * the container is already being visited (cycle);
      * the shared container-visit ``budget`` is exhausted;
      * snapshotting a container raises ``RuntimeError``;
      * two entries of an unordered container share a sort key but differ in
        value, so no unambiguous order exists;
      * any child fails.

    Args:
        obj: Value to canonicalize.
        depth: Current nesting depth.
        max_depth: Depth at which traversal gives up.
        active: Set of ``id()`` values of containers on the current path.
        memo: Maps ``id()`` of finished containers to their result pair.
        budget: Dict with a ``"remaining"`` counter of container visits.
    """
    if depth >= max_depth:
        return _FAILED_SIGNATURE

    active = set() if active is None else active
    memo = {} if memo is None else memo
    if budget is None:
        budget = {"remaining": _MAX_SIGNATURE_CONTAINER_VISITS}

    kind = type(obj)
    if kind in _PRIMITIVE_SIGNATURE_TYPES:
        return obj, _primitive_signature_sort_key(obj)
    if kind is Unhashable or kind not in _CONTAINER_SIGNATURE_TYPES:
        return _FAILED_SIGNATURE

    ident = id(obj)
    if ident in memo:
        return memo[ident]
    if ident in active:
        return _FAILED_SIGNATURE

    # Only containers that actually get expanded consume budget.
    budget["remaining"] -= 1
    if budget["remaining"] < 0:
        return _FAILED_SIGNATURE

    def canonical(child):
        return _signature_to_hashable_impl(child, depth + 1, max_depth, active, memo, budget)

    active.add(ident)
    try:
        # Snapshot first so the traversal works on a fixed list of children.
        try:
            children = list(obj.items()) if kind is dict else list(obj)
        except RuntimeError:
            return _FAILED_SIGNATURE

        # Every entry is a (sort_key, value) pair, whatever the container kind.
        entries = []
        if kind is dict:
            tag = "dict"
            unordered = True
            for raw_key, raw_value in children:
                key_pair = canonical(raw_key)
                if key_pair is _FAILED_SIGNATURE:
                    return _FAILED_SIGNATURE
                value_pair = canonical(raw_value)
                if value_pair is _FAILED_SIGNATURE:
                    return _FAILED_SIGNATURE
                entries.append(((key_pair[1], value_pair[1]), (key_pair[0], value_pair[0])))
        else:
            if kind is list or kind is tuple:
                tag = "list" if kind is list else "tuple"
                unordered = False
            else:
                tag = "set" if kind is set else "frozenset"
                unordered = True
            for raw_item in children:
                item_pair = canonical(raw_item)
                if item_pair is _FAILED_SIGNATURE:
                    return _FAILED_SIGNATURE
                entries.append((item_pair[1], item_pair[0]))

        if unordered:
            entries.sort(key=lambda entry: entry[0])
            # Equal sort keys with different values mean the order is ambiguous.
            if any(
                left_key == right_key and left_value != right_value
                for (left_key, left_value), (right_key, right_value) in zip(entries, entries[1:])
            ):
                return _FAILED_SIGNATURE

        value = (tag, tuple(entry_value for _, entry_value in entries))
        sort_key = (tag, tuple(entry_key for entry_key, _ in entries))
    finally:
        active.discard(ident)

    memo[ident] = (value, sort_key)
    return memo[ident]
|
||||
|
||||
|
||||
def _signature_to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
    """Return the hashable cache-signature form of ``obj``, failing closed.

    Args:
        obj: Value to canonicalize.
        max_nodes: Initial container-visit budget for the traversal.

    Returns:
        The hashable value from ``_signature_to_hashable_impl``, or a fresh
        ``Unhashable()`` when canonicalization fails.
    """
    outcome = _signature_to_hashable_impl(obj, budget={"remaining": max_nodes})
    return Unhashable() if outcome is _FAILED_SIGNATURE else outcome[0]
|
||||
|
||||
|
||||
def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
    """Convert sanitized prompt inputs into a stable hashable representation.

    The input is expected to already be sanitized to plain built-in containers,
    but this function still fails safe for anything unexpected. Traversal is
    iterative and memoized so shared built-in substructures do not trigger
    exponential re-walks during cache-key construction.

    Args:
        obj: Value to convert.
        max_nodes: Maximum number of containers that may be expanded before
            the whole conversion gives up.

    Returns:
        * ``obj`` itself when its exact type is a primitive signature type or
          ``Unhashable``.
        * A new ``Unhashable()`` when ``obj`` is of an unsupported type, when
          more than ``max_nodes`` containers would be expanded, or when the
          root container could not be resolved.
        * Otherwise a tagged tuple ``(tag, children)`` where ``tag`` is one of
          ``"dict"``, ``"list"``, ``"tuple"``, ``"set"`` or ``"frozenset"``.

        Nested containers that are part of a reference cycle, that raise
        ``RuntimeError`` while being read, or (for set-likes) whose ordering
        is ambiguous are represented by ``Unhashable()`` in place.
    """
    obj_type = type(obj)
    if obj_type in _PRIMITIVE_SIGNATURE_TYPES or obj_type is Unhashable:
        return obj
    if obj_type not in _CONTAINER_SIGNATURE_TYPES:
        return Unhashable()

    memo = {}       # id(value) -> finished representation (or Unhashable)
    active = set()  # ids of containers whose children are still being walked
    snapshots = {}  # id(container) -> point-in-time copy of its children
    sort_memo = {}  # cache shared across all _sanitized_sort_key calls
    processed = 0   # number of containers expanded so far
    stack = [(obj, False)]

    def resolve_value(value):
        """Resolve a child value from the completed memo table when available."""
        value_type = type(value)
        if value_type in _PRIMITIVE_SIGNATURE_TYPES or value_type is Unhashable:
            return value
        return memo.get(id(value), Unhashable())

    def resolve_unordered_values(current_items, container_tag):
        """Resolve a set-like container or fail closed if ordering is ambiguous."""
        try:
            ordered_items = [
                (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
                for item in current_items
            ]
            ordered_items.sort(key=lambda item: item[0])
        except RuntimeError:
            return Unhashable()

        # Equal sort keys with different values mean the final order would
        # depend on iteration order, so refuse to produce a signature.
        for index in range(1, len(ordered_items)):
            previous_key, previous_value = ordered_items[index - 1]
            current_key, current_value = ordered_items[index]
            if previous_key == current_key and previous_value != current_value:
                return Unhashable()

        return (container_tag, tuple(value for _, value in ordered_items))

    def snapshot_children(container, container_type):
        """Copy a container's children so later mutation cannot alter the walk."""
        if container_type is dict:
            return list(container.items())
        return list(container)

    def build_representation(container, container_type, container_id):
        """Assemble the tagged tuple for a container whose children are done."""
        items = snapshots.pop(container_id, None)
        if items is None:
            items = snapshot_children(container, container_type)

        if container_type is dict:
            return (
                "dict",
                tuple((resolve_value(k), resolve_value(v)) for k, v in items),
            )
        if container_type is list:
            return ("list", tuple(resolve_value(item) for item in items))
        if container_type is tuple:
            return ("tuple", tuple(resolve_value(item) for item in items))
        if container_type is set:
            return resolve_unordered_values(items, "set")
        # Any remaining supported container type is treated as a frozenset.
        return resolve_unordered_values(items, "frozenset")

    while stack:
        current, expanded = stack.pop()
        current_type = type(current)

        if current_type in _PRIMITIVE_SIGNATURE_TYPES or current_type is Unhashable:
            continue
        if current_type not in _CONTAINER_SIGNATURE_TYPES:
            memo[id(current)] = Unhashable()
            continue

        current_id = id(current)
        if current_id in memo:
            continue

        if expanded:
            # Second visit: every child has been handled, so finalize.
            active.discard(current_id)
            try:
                memo[current_id] = build_representation(current, current_type, current_id)
            except RuntimeError:
                memo[current_id] = Unhashable()
            continue

        if current_id in active:
            # Reached a container that is still being expanded: a cycle.
            memo[current_id] = Unhashable()
            continue

        processed += 1
        if processed > max_nodes:
            return Unhashable()

        # First visit: schedule the finalize step, then the children on top of
        # it so they are all processed before the container itself.
        active.add(current_id)
        stack.append((current, True))
        try:
            items = snapshot_children(current, current_type)
        except RuntimeError:
            memo[current_id] = Unhashable()
            active.discard(current_id)
            continue
        snapshots[current_id] = items

        if current_type is dict:
            for key, value in reversed(items):
                stack.append((value, False))
                stack.append((key, False))
        else:
            for item in reversed(items):
                stack.append((item, False))

    return memo.get(id(obj), Unhashable())
|
||||
|
||||
class CacheKeySetID(CacheKeySet):
|
||||
"""Cache-key strategy that keys nodes by node id and class type."""
|
||||
@ -238,7 +414,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||
for ancestor_id in ancestors:
|
||||
signature.append(await self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
||||
return to_hashable(signature)
|
||||
return _signature_to_hashable(signature)
|
||||
|
||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
"""Build the cache-signature fragment for a node's immediate inputs.
|
||||
@ -261,7 +437,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, _sanitize_signature_input(inputs[key])))
|
||||
signature.append((key, inputs[key]))
|
||||
return signature
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
|
||||
@ -32,7 +32,7 @@ async def cache_control(
|
||||
)
|
||||
|
||||
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
response.headers.setdefault("Cache-Control", "no-store")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
|
||||
6
nodes.py
6
nodes.py
@ -1724,6 +1724,8 @@ class LoadImage:
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
@ -1748,8 +1750,8 @@ class LoadImage:
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
output_images.append(image.to(dtype=dtype))
|
||||
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.41.19
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
|
||||
@ -310,7 +310,7 @@ class PromptServer():
|
||||
@routes.get("/")
|
||||
async def get_root(request):
|
||||
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
response.headers['Cache-Control'] = 'no-store, must-revalidate'
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
return response
|
||||
|
||||
289
tests-unit/execution_test/caching_test.py
Normal file
289
tests-unit/execution_test/caching_test.py
Normal file
@ -0,0 +1,289 @@
|
||||
"""Unit tests for cache-signature canonicalization hardening."""
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _DummyNode:
    """Minimal node stub used to satisfy cache-signature class lookups."""

    @staticmethod
    def INPUT_TYPES():
        """Return a minimal input schema: a ``required`` section with no inputs."""
        return {"required": {}}
|
||||
|
||||
|
||||
class _FakeDynPrompt:
    """Small DynamicPrompt stand-in with only the methods these tests need."""

    def __init__(self, nodes_by_id):
        """Store the test nodes, keyed by node id."""
        self._nodes_by_id = nodes_by_id

    def has_node(self, node_id):
        """Return whether the fake prompt contains the requested node."""
        return node_id in self._nodes_by_id

    def get_node(self, node_id):
        """Return the stored node payload for the requested id.

        Raises:
            KeyError: If ``node_id`` was not supplied at construction time.
        """
        return self._nodes_by_id[node_id]
|
||||
|
||||
|
||||
class _FakeIsChangedCache:
    """Async stub for `is_changed` lookups used by cache-key generation."""

    def __init__(self, values):
        """Store canned `is_changed` responses keyed by node id."""
        self._values = values

    async def get(self, node_id):
        """Return the canned `is_changed` value for a node.

        Raises:
            KeyError: If no canned value was registered for ``node_id``.
        """
        return self._values[node_id]
|
||||
|
||||
|
||||
class _OpaqueValue:
    """Hashable opaque object used to exercise fail-closed unordered hashing paths.

    Defines no ``__eq__``/``__hash__`` of its own, so instances use the default
    identity-based behavior inherited from ``object``.
    """
|
||||
|
||||
|
||||
@pytest.fixture
def caching_module(monkeypatch):
    """Import `comfy_execution.caching` with lightweight stub dependencies.

    Empty stand-in modules for ``torch``, ``psutil``, ``nodes`` and
    ``comfy_execution.graph`` are placed in ``sys.modules`` before the caching
    module is imported.

    Returns:
        A ``(module, nodes_module)`` pair: the freshly imported caching module
        and the stub ``nodes`` module, whose ``NODE_CLASS_MAPPINGS`` starts
        empty so tests can register node classes on it.
    """
    torch_module = types.ModuleType("torch")
    psutil_module = types.ModuleType("psutil")
    nodes_module = types.ModuleType("nodes")
    nodes_module.NODE_CLASS_MAPPINGS = {}
    graph_module = types.ModuleType("comfy_execution.graph")

    class DynamicPrompt:
        """Placeholder graph type so the caching module can import cleanly."""

    graph_module.DynamicPrompt = DynamicPrompt

    monkeypatch.setitem(sys.modules, "torch", torch_module)
    monkeypatch.setitem(sys.modules, "psutil", psutil_module)
    monkeypatch.setitem(sys.modules, "nodes", nodes_module)
    monkeypatch.setitem(sys.modules, "comfy_execution.graph", graph_module)
    # Drop any cached copy so the import below binds against the stubs.
    monkeypatch.delitem(sys.modules, "comfy_execution.caching", raising=False)

    module = importlib.import_module("comfy_execution.caching")
    module = importlib.reload(module)
    return module, nodes_module
|
||||
|
||||
|
||||
def test_signature_to_hashable_handles_shared_builtin_substructures(caching_module):
    """Shared built-in substructures should canonicalize without collapsing to Unhashable."""
    caching, _ = caching_module
    shared = [{"value": 1}, {"value": 2}]

    # The same list object appears twice in the outer list.
    signature = caching._signature_to_hashable([shared, shared])

    assert signature[0] == "list"
    # Both references must canonicalize to the same representation.
    assert signature[1][0] == signature[1][1]
    assert signature[1][0][0] == "list"
    assert signature[1][0][1][0] == ("dict", (("value", 1),))
    assert signature[1][0][1][1] == ("dict", (("value", 2),))
|
||||
|
||||
|
||||
def test_signature_to_hashable_fails_closed_on_opaque_values(caching_module):
    """Opaque values should collapse the full signature to Unhashable immediately."""
    caching, _ = caching_module

    # A bare ``object()`` is not a supported signature type.
    signature = caching._signature_to_hashable(["safe", object()])

    assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_signature_to_hashable_stops_descending_after_failure(caching_module, monkeypatch):
    """Once canonicalization fails, later recursive descent should stop immediately."""
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = object()
    marker_seen = False

    def tracking_canonicalize(obj, *args, **kwargs):
        """Track whether recursion reaches the nested marker after failure."""
        nonlocal marker_seen
        if obj is marker:
            marker_seen = True
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", tracking_canonicalize)

    # The opaque first element fails; the marker sits in a later sibling.
    signature = caching._signature_to_hashable([object(), [marker]])

    assert isinstance(signature, caching.Unhashable)
    assert marker_seen is False
|
||||
|
||||
|
||||
def test_signature_to_hashable_snapshots_list_before_recursing(caching_module, monkeypatch):
    """List canonicalization should read a point-in-time snapshot before recursive descent."""
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = ("marker",)
    values = [marker, 2]

    def mutating_canonicalize(obj, *args, **kwargs):
        """Mutate the live list during recursion to verify snapshot-based traversal."""
        if obj is marker:
            values[1] = 3
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)

    signature = caching._signature_to_hashable(values)

    # The signature reflects the list as it was when traversal started ...
    assert signature == ("list", (("tuple", ("marker",)), 2))
    # ... even though the live list really was mutated mid-traversal.
    assert values[1] == 3
|
||||
|
||||
|
||||
def test_signature_to_hashable_snapshots_dict_before_recursing(caching_module, monkeypatch):
    """Dict canonicalization should read a point-in-time snapshot before recursive descent."""
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = ("marker",)
    values = {"first": marker, "second": 2}

    def mutating_canonicalize(obj, *args, **kwargs):
        """Mutate the live dict during recursion to verify snapshot-based traversal."""
        if obj is marker:
            values["second"] = 3
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", mutating_canonicalize)

    signature = caching._signature_to_hashable(values)

    # The signature reflects the dict as it was when traversal started ...
    assert signature == ("dict", (("first", ("tuple", ("marker",))), ("second", 2)))
    # ... even though the live dict really was mutated mid-traversal.
    assert values["second"] == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "container_factory",
    [
        lambda marker: [marker],
        lambda marker: (marker,),
        lambda marker: {marker},
        lambda marker: frozenset({marker}),
        lambda marker: {marker: "value"},
    ],
    # Explicit ids: lambdas would otherwise show up as "<lambda>0", "<lambda>1", ...
    ids=["list", "tuple", "set", "frozenset", "dict"],
)
def test_signature_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
    """Traversal RuntimeError should degrade canonicalization to Unhashable."""
    caching, _ = caching_module
    original = caching._signature_to_hashable_impl
    marker = object()

    def raising_canonicalize(obj, *args, **kwargs):
        """Raise a traversal RuntimeError for the marker value and delegate otherwise."""
        if obj is marker:
            raise RuntimeError("container changed during iteration")
        return original(obj, *args, **kwargs)

    monkeypatch.setattr(caching, "_signature_to_hashable_impl", raising_canonicalize)

    signature = caching._signature_to_hashable(container_factory(marker))

    assert isinstance(signature, caching.Unhashable)
|
||||
|
||||
|
||||
def test_to_hashable_handles_shared_builtin_substructures(caching_module):
    """The legacy helper should still hash sanitized built-ins stably when used directly."""
    caching, _ = caching_module
    shared = [{"value": 1}, {"value": 2}]

    # The same list object appears twice in the outer list.
    sanitized = [shared, shared]
    hashable = caching.to_hashable(sanitized)

    assert hashable[0] == "list"
    assert hashable[1][0] == hashable[1][1]
    assert hashable[1][0][0] == "list"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "container_factory",
    [
        set,
        frozenset,
    ],
)
def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
    """Traversal RuntimeError should degrade unordered hash conversion to Unhashable."""
    caching, _ = caching_module

    def raising_sort_key(obj, *args, **kwargs):
        """Raise a traversal RuntimeError while unordered values are canonicalized."""
        raise RuntimeError("container changed during iteration")

    monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key)

    hashable = caching.to_hashable(container_factory({"value"}))

    assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_signature_to_hashable_fails_closed_for_ambiguous_dict_ordering(caching_module):
    """Ambiguous dict sort ties should fail closed instead of depending on input order."""
    caching, _ = caching_module
    # Two distinct opaque keys and values that the canonicalizer cannot tell apart.
    ambiguous = {
        _OpaqueValue(): _OpaqueValue(),
        _OpaqueValue(): _OpaqueValue(),
    }

    sanitized = caching._signature_to_hashable(ambiguous)

    assert isinstance(sanitized, caching.Unhashable)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "container_factory",
    [
        set,
        frozenset,
    ],
)
def test_to_hashable_fails_closed_for_ambiguous_unordered_values(caching_module, container_factory):
    """Ambiguous unordered values should fail closed instead of depending on iteration order."""
    caching, _ = caching_module
    # Two distinct opaque members give the set-like container no stable order.
    container = container_factory({_OpaqueValue(), _OpaqueValue()})

    hashable = caching.to_hashable(container)

    assert isinstance(hashable, caching.Unhashable)
|
||||
|
||||
|
||||
def test_get_node_signature_returns_top_level_unhashable_for_tainted_signature(caching_module, monkeypatch):
    """Tainted full signatures should fail closed before `to_hashable()` runs."""
    caching, nodes_module = caching_module
    monkeypatch.setitem(nodes_module.NODE_CLASS_MAPPINGS, "UnitTestNode", _DummyNode)
    # Any call into the legacy helper is a test failure.
    monkeypatch.setattr(
        caching,
        "to_hashable",
        lambda *_args, **_kwargs: pytest.fail("to_hashable should not run for tainted signatures"),
    )

    # A list that contains itself: a cyclic `is_changed` value.
    is_changed_value = []
    is_changed_value.append(is_changed_value)

    dynprompt = _FakeDynPrompt(
        {
            "node": {
                "class_type": "UnitTestNode",
                "inputs": {"value": 5},
            }
        }
    )
    key_set = caching.CacheKeySetInputSignature(
        dynprompt,
        ["node"],
        _FakeIsChangedCache({"node": is_changed_value}),
    )

    signature = asyncio.run(key_set.get_node_signature(dynprompt, "node"))

    assert isinstance(signature, caching.Unhashable)
|
||||
@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
|
||||
},
|
||||
# JavaScript/CSS scenarios
|
||||
{
|
||||
"name": "js_no_cache",
|
||||
"name": "js_no_store",
|
||||
"path": "/script.js",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"expected_cache": "no-store",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "css_no_cache",
|
||||
"name": "css_no_store",
|
||||
"path": "/styles.css",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"expected_cache": "no-store",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "index_json_no_cache",
|
||||
"name": "index_json_no_store",
|
||||
"path": "/api/index.json",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"expected_cache": "no-store",
|
||||
"should_have_header": True,
|
||||
},
|
||||
{
|
||||
"name": "localized_index_json_no_cache",
|
||||
"name": "localized_index_json_no_store",
|
||||
"path": "/templates/index.zh.json",
|
||||
"status": 200,
|
||||
"expected_cache": "no-cache",
|
||||
"expected_cache": "no-store",
|
||||
"should_have_header": True,
|
||||
},
|
||||
# Non-matching files
|
||||
|
||||
Loading…
Reference in New Issue
Block a user