Compare commits

...

6 Commits

Author SHA1 Message Date
Christian Byrne
05a37fd2e0
Merge b6b420190c into de9ada6a41 2026-02-03 00:33:34 +01:00
rattus
de9ada6a41
Dynamic VRAM unloading fix (#12227)
* mp: fix full dynamic unloading

This was not unloading dynamic models when requesting a full unload via
the unpatch() code path.

This was ok, i your workflow was all dynamic models but fails with big
VRAM leaks if you need to fully unload something for a regular ModelPatcher

It also fices the "unload models" button.

* mm: load models outside of Aimdo Mempool

In dynamic_vram mode, escape the Aimdo mempool and load into the regular
mempool. Use a dummy thread to do it.
2026-02-02 17:35:20 -05:00
rattus
37f711d4a1
mm: Fix cast buffers with intel offloading (#12229)
Intel has offloading support but there were some nvidia calls in the
new cast buffer stuff.
2026-02-02 17:34:46 -05:00
bymyself
b6b420190c fix: only add timestamps to browser-previewed outputs
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Reverted timestamp addition for non-previewed files:
- SaveLatent (.latent) - not previewed in browser
- CheckpointSave, CLIPSave, VAESave (.safetensors) - model files
- ExtractAndSaveLoRA, SaveLoRA (.safetensors) - LoRA files

Kept timestamps for browser-previewed content:
- Images (PNG, SVG)
- Videos (WebM, MP4)
- Audio
- 3D models (GLB)

Amp-Thread-ID: https://ampcode.com/threads/T-019c1be0-7238-71ec-9c0b-2d4468d61202
2026-02-01 18:03:41 -08:00
bymyself
6c2223ade9 fix: convert tests to unittest, remove unused import
Amp-Thread-ID: https://ampcode.com/threads/T-019c17ed-fd96-71ed-8055-83a8cd6f8f2b
2026-01-31 22:41:33 -08:00
bymyself
0f259cabdd feat: add timestamp to output filenames for cache-busting
Add get_timestamp() and format_output_filename() utilities to folder_paths.py
that generate unique filenames with UTC timestamps. This eliminates the need
for client-side cache-busting query parameters.

New filename format: prefix_00001_20260131-220945-123456_.ext

Updated all save nodes to use the new format:
- nodes.py (SaveImage, SaveLatent, SaveImageWebsocket)
- comfy_api/latest/_ui.py (UILatent)
- comfy_extras/nodes_video.py (SaveWEBM, SaveAnimatedPNG, SaveAnimatedWEBP)
- comfy_extras/nodes_images.py (SaveSVG)
- comfy_extras/nodes_hunyuan3d.py (Save3D)
- comfy_extras/nodes_model_merging.py (SaveCheckpointSimple)
- comfy_extras/nodes_lora_extract.py (LoraSave)
- comfy_extras/nodes_train.py (SaveEmbedding)

Amp-Thread-ID: https://ampcode.com/threads/T-019c17e5-1c0a-736f-970d-e411aae222fc
2026-01-31 22:30:57 -08:00
9 changed files with 166 additions and 27 deletions

View File

@ -19,7 +19,8 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args, PerformanceFeature
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import threading
import torch
import sys
import platform
@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@ -746,8 +747,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_loaded_models.insert(0, loaded_model)
return
def load_model_gpu(model):
return load_models_gpu([model])
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
with torch.inference_mode():
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
soft_empty_cache()
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
#Deliberately load models outside of the Aimdo mempool so they can be retained accross
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
#thread local. So exploit that to escape context
if enables_dynamic_vram():
t = threading.Thread(
target=load_models_gpu_thread,
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
)
t.start()
t.join()
else:
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
def loaded_models(only_currently_used=False):
output = []
@ -1112,11 +1130,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
return None
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
torch.cuda.synchronize()
synchronize()
del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
torch.cuda.empty_cache()
soft_empty_cache()
with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
@ -1132,9 +1150,7 @@ def reset_cast_buffers():
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
STREAM_CAST_BUFFERS.clear()
if comfy.memory_management.aimdo_allocator is None:
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
torch.cuda.empty_cache()
soft_empty_cache()
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
@ -1284,7 +1300,7 @@ def discard_cuda_async_error():
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
torch.cuda.synchronize()
synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
@ -1688,6 +1704,12 @@ def lora_compute_dtype(device):
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def synchronize():
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
@ -1713,9 +1735,6 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary()
return ""
#TODO: might be cleaner to put this somewhere else
import threading
class InterruptProcessingException(Exception):
pass

View File

@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None)
self.partially_unload(None, 1e32)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above

View File

@ -146,8 +146,7 @@ class ImageSaveHelper:
metadata = ImageSaveHelper._create_png_metadata(cls)
for batch_number, image_tensor in enumerate(images):
img = ImageSaveHelper._convert_tensor_to_pil(image_tensor)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
file = folder_paths.format_output_filename(filename, counter, "png", batch_num=str(batch_number))
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level)
results.append(SavedResult(file, subfolder, folder_type))
counter += 1
@ -176,7 +175,7 @@ class ImageSaveHelper:
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
metadata = ImageSaveHelper._create_animated_png_metadata(cls)
file = f"{filename}_{counter:05}_.png"
file = folder_paths.format_output_filename(filename, counter, "png")
save_path = os.path.join(full_output_folder, file)
pil_images[0].save(
save_path,
@ -220,7 +219,7 @@ class ImageSaveHelper:
)
pil_images = [ImageSaveHelper._convert_tensor_to_pil(img) for img in images]
pil_exif = ImageSaveHelper._create_webp_metadata(pil_images[0], cls)
file = f"{filename}_{counter:05}_.webp"
file = folder_paths.format_output_filename(filename, counter, "webp")
pil_images[0].save(
os.path.join(full_output_folder, file),
save_all=True,
@ -284,8 +283,7 @@ class AudioSaveHelper:
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
file = folder_paths.format_output_filename(filename, counter, format, batch_num=str(batch_number))
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially

View File

@ -642,7 +642,7 @@ class SaveGLB(IO.ComfyNode):
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
f = folder_paths.format_output_filename(filename, counter, "glb")
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
results.append({
"filename": f,

View File

@ -460,8 +460,7 @@ class SaveSVGNode(IO.ComfyNode):
for batch_number, svg_bytes in enumerate(svg.data):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.svg"
file = folder_paths.format_output_filename(filename, counter, "svg", batch_num=str(batch_number))
# Read SVG content
svg_bytes.seek(0)

View File

@ -36,7 +36,7 @@ class SaveWEBM(io.ComfyNode):
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
)
file = f"{filename}_{counter:05}_.webm"
file = folder_paths.format_output_filename(filename, counter, "webm")
container = av.open(os.path.join(full_output_folder, file), mode="w")
if cls.hidden.prompt is not None:
@ -102,7 +102,7 @@ class SaveVideo(io.ComfyNode):
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
file = folder_paths.format_output_filename(filename, counter, Types.VideoContainer.get_extension(format))
video.save_to(
os.path.join(full_output_folder, file),
format=Types.VideoContainer(format),

View File

@ -4,6 +4,7 @@ import os
import time
import mimetypes
import logging
from datetime import datetime, timezone
from typing import Literal, List
from collections.abc import Collection
@ -11,6 +12,46 @@ from comfy.cli_args import args
supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
def get_timestamp() -> str:
"""Generate a filesystem-safe timestamp string for output filenames.
Returns a UTC timestamp in the format YYYYMMDD-HHMMSS-ffffff (microseconds)
which is human-readable, lexicographically sortable, and Windows-safe.
"""
now = datetime.now(timezone.utc)
return now.strftime("%Y%m%d-%H%M%S-%f")
def format_output_filename(
filename: str,
counter: int,
ext: str,
*,
batch_num: str | None = None,
timestamp: str | None = None,
) -> str:
"""Format an output filename with counter and timestamp for cache-busting.
Args:
filename: The base filename prefix
counter: The numeric counter for uniqueness
ext: The file extension (with or without leading dot)
batch_num: Optional batch number to replace %batch_num% placeholder
timestamp: Optional timestamp string (defaults to current UTC time)
Returns:
Formatted filename like: filename_00001_20260131-123456-789012_.ext
"""
ext = ext.lstrip(".")
if timestamp is None:
timestamp = get_timestamp()
if batch_num is not None:
filename = filename.replace("%batch_num%", batch_num)
return f"{filename}_{counter:05}_{timestamp}_.{ext}"
folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}
# --base-directory - Resets all default paths configured in folder_paths with a new base path

View File

@ -1667,8 +1667,7 @@ class SaveImage:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
file = folder_paths.format_output_filename(filename, counter, "png", batch_num=str(batch_number))
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
results.append({
"filename": file,

View File

@ -0,0 +1,83 @@
"""Tests for folder_paths.format_output_filename and get_timestamp functions."""
import sys
import os
import unittest
# Add the ComfyUI root to the path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import folder_paths
class TestGetTimestamp(unittest.TestCase):
"""Tests for get_timestamp function."""
def test_returns_string(self):
"""Should return a string."""
result = folder_paths.get_timestamp()
self.assertIsInstance(result, str)
def test_format_matches_expected_pattern(self):
"""Should return format YYYYMMDD-HHMMSS-ffffff."""
result = folder_paths.get_timestamp()
# Pattern: 8 digits, hyphen, 6 digits, hyphen, 6 digits
pattern = r"^\d{8}-\d{6}-\d{6}$"
self.assertRegex(result, pattern)
def test_is_filesystem_safe(self):
"""Should not contain characters that are unsafe for filenames."""
result = folder_paths.get_timestamp()
unsafe_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|', ' ']
for char in unsafe_chars:
self.assertNotIn(char, result)
class TestFormatOutputFilename(unittest.TestCase):
"""Tests for format_output_filename function."""
def test_basic_format(self):
"""Should format filename with counter and timestamp."""
result = folder_paths.format_output_filename("test", 1, "png")
# Pattern: test_00001_YYYYMMDD-HHMMSS-ffffff_.png
pattern = r"^test_00001_\d{8}-\d{6}-\d{6}_\.png$"
self.assertRegex(result, pattern)
def test_counter_padding(self):
"""Should pad counter to 5 digits."""
result = folder_paths.format_output_filename("test", 42, "png")
self.assertIn("_00042_", result)
def test_extension_with_leading_dot(self):
"""Should handle extension with leading dot."""
result = folder_paths.format_output_filename("test", 1, ".png")
self.assertTrue(result.endswith("_.png"))
self.assertNotIn("..png", result)
def test_extension_without_leading_dot(self):
"""Should handle extension without leading dot."""
result = folder_paths.format_output_filename("test", 1, "webm")
self.assertTrue(result.endswith("_.webm"))
def test_batch_num_replacement(self):
"""Should replace %batch_num% placeholder."""
result = folder_paths.format_output_filename("test_%batch_num%", 1, "png", batch_num="3")
self.assertIn("test_3_", result)
self.assertNotIn("%batch_num%", result)
def test_custom_timestamp(self):
"""Should use provided timestamp instead of generating one."""
custom_ts = "20260101-120000-000000"
result = folder_paths.format_output_filename("test", 1, "png", timestamp=custom_ts)
self.assertIn(custom_ts, result)
def test_different_extensions(self):
"""Should work with various extensions."""
extensions = ["png", "webp", "webm", "svg", "glb", "safetensors", "latent"]
for ext in extensions:
result = folder_paths.format_output_filename("test", 1, ext)
self.assertTrue(result.endswith(f"_.{ext}"))
if __name__ == "__main__":
unittest.main()