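"""Workflow regression tests for the embedded ComfyUI client.

Every JSON workflow bundled in the ``workflows`` package is executed through an
embedded Comfy client, parametrized over a matrix of attention, async-offload,
pinned-memory and performance-feature settings, while peak CPU RAM and GPU VRAM
usage are recorded by ResourceMonitor and logged per test.
"""
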
import importlib.resources
import itertools
import json
import logging
import threading
import time
from importlib.abc import Traversable
from typing import Any, AsyncGenerator

import psutil

try:
    import pynvml
except ImportError:
    pynvml = None

import pytest

from comfy.api.components.schema.prompt import Prompt
from comfy.client.embedded_comfy_client import Comfy
from comfy.distributed.process_pool_executor import ProcessPoolExecutor
from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile, HuggingFile
from comfy_extras.nodes.nodes_audio import TorchAudioNotFoundError
from comfy.cli_args import default_configuration
from comfy.cli_args_types import PerformanceFeature

from . import workflows

logger = logging.getLogger(__name__)


class ResourceMonitor:
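    """Context manager that samples peak resource usage on a background thread.

    While active it polls the current process tree every ``interval`` seconds and
    records the largest observed resident set size in ``peak_cpu_ram`` and, when
    pynvml is available, the largest GPU memory attributed to those PIDs in
    ``peak_gpu_vram`` (both in bytes).

    Typical use::

        with ResourceMonitor() as monitor:
            ...  # run the workload to measure
        logger.info("peak RAM: %d, peak VRAM: %d", monitor.peak_cpu_ram, monitor.peak_gpu_vram)
    """
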
    def __init__(self, interval: float = 0.1):
        self.interval = interval
        self.peak_cpu_ram = 0
        self.peak_gpu_vram = 0
        self._stop_event = threading.Event()
        self._thread = None
        self._pynvml_available = False
        self._gpu_handles = []

    def _monitor(self):
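        """Polling loop executed on the background thread until the stop event is set."""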
        current_process = psutil.Process()
        while not self._stop_event.is_set():
            # Monitor CPU RAM (RSS) for process tree
            try:
                children = current_process.children(recursive=True)
                processes = [current_process] + children
                pids = {p.pid for p in processes}

                total_rss = 0
                for p in processes:
                    try:
                        total_rss += p.memory_info().rss
                    except (psutil.NoSuchProcess, psutil.AccessDenied):
                        pass
                self.peak_cpu_ram = max(self.peak_cpu_ram, total_rss)

                # Monitor GPU VRAM if available
                if self._pynvml_available and self._gpu_handles:
                    total_vram = 0
                    try:
                        # Iterate over all detected GPUs
                        for handle in self._gpu_handles:
                            # Get all processes running on the GPU
                            try:
                                compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
                                graphics_procs = pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)

                                # Filter for our process tree; usedGpuMemory can be None when
                                # the driver does not report it, so skip those entries
                                for p in compute_procs + graphics_procs:
                                    if p.pid in pids and p.usedGpuMemory is not None:
                                        total_vram += p.usedGpuMemory
                            except Exception:
                                pass  # Skip errors for specific GPU queries

                        self.peak_gpu_vram = max(self.peak_gpu_vram, total_vram)
                    except Exception:
                        pass
            except Exception:
                pass

            time.sleep(self.interval)

    def __enter__(self):
        if pynvml:
            try:
                pynvml.nvmlInit()
                device_count = pynvml.nvmlDeviceGetCount()
                self._gpu_handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(device_count)]
                self._pynvml_available = True
            except Exception as e:
                logger.warning(f"Failed to initialize pynvml for VRAM monitoring: {e}")

        self._thread = threading.Thread(target=self._monitor, daemon=True)
        self._thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._stop_event.set()
        if self._thread:
            self._thread.join()
        if self._pynvml_available:
            try:
                pynvml.nvmlShutdown()
            except Exception:
                pass


def _generate_config_params():
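    """Yield every combination of the enabled attention, async-offload, pinned-memory and fast options.

    Used as the ``params`` of the ``client`` fixture, so each workflow runs once per configuration.
    """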
    attn_keys = [
        "use_pytorch_cross_attention",
        # "use_split_cross_attention",
        # "use_quad_cross_attention",
        # "use_sage_attention",
        # "use_flash_attention"
    ]
    attn_options = [
        {k: (k == target_key) for k in attn_keys}
        for target_key in attn_keys
    ]

    async_options = [
        {"disable_async_offload": False},
        {"disable_async_offload": True},
    ]
    pinned_options = [
        {"disable_pinned_memory": False},
        {"disable_pinned_memory": True},
    ]
    fast_options = [
        {"fast": set()},
        # {"fast": {PerformanceFeature.Fp16Accumulation}},
        # {"fast": {PerformanceFeature.Fp8MatrixMultiplication}},
        {"fast": {PerformanceFeature.CublasOps}},
    ]

    for attn, asnc, pinned, fst in itertools.product(attn_options, async_options, pinned_options, fast_options):
        config_update = {}
        config_update.update(attn)
        config_update.update(asnc)
        config_update.update(pinned)
        config_update.update(fst)
        yield config_update


@pytest.fixture(scope="function", autouse=False, params=_generate_config_params(), ids=lambda p: ",".join(f"{k}={v}" for k, v in p.items()))
async def client(tmp_path_factory, request) -> AsyncGenerator[Any, Any]:
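    """Start an embedded Comfy client in a single-worker subprocess using the parametrized configuration."""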
    config = default_configuration()
    # skipping custom node loading makes startup a little faster
    config.disable_all_custom_nodes = True
    config.update(request.param)
    # use ProcessPoolExecutor to respect various config settings
    with ProcessPoolExecutor(max_workers=1) as executor:
        async with Comfy(configuration=config, executor=executor) as client:
            yield client


def _prepare_for_workflows() -> dict[str, Traversable]:
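    """Register the models the test workflows depend on and return the bundled workflow JSON files keyed by filename."""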
    add_known_models("loras", HuggingFile("artificialguybr/pixelartredmond-1-5v-pixel-art-loras-for-sd-1-5", "PixelArtRedmond15V-PixelArt-PIXARFK.safetensors"))
    add_known_models("checkpoints", HuggingFile("autismanon/modeldump", "cardosAnime_v20.safetensors"))

    return {f.name: f for f in importlib.resources.files(workflows).iterdir() if f.is_file() and f.name.endswith(".json")}


@pytest.mark.asyncio
@pytest.mark.parametrize("workflow_name, workflow_file", _prepare_for_workflows().items())
async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu: bool, client: Comfy):
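    """Run a bundled workflow end to end and assert that its save/preview node produced output."""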
    if not has_gpu:
        pytest.skip("requires gpu")

    if "compile" in workflow_name:
        pytest.skip("compilation has regressed in 0.4.0 and later because upcast weights are now permitted to be compiled, causing OOM errors in most cases")

    workflow = json.loads(workflow_file.read_text(encoding="utf8"))

    prompt = Prompt.validate(workflow)
    # todo: add all the models we want to test a bit more elegantly
    outputs = {}

    start_time = time.time()
    monitor = ResourceMonitor()
    try:
        with monitor:
            outputs = await client.queue_prompt(prompt)
    except TorchAudioNotFoundError:
        pytest.skip("requires torchaudio")
    finally:
        end_time = time.time()
        duration = end_time - start_time
        ram_gb = monitor.peak_cpu_ram / (1024**3)
        vram_gb = monitor.peak_gpu_vram / (1024**3)
        logger.info(f"Test {workflow_name} with client {client} took {duration:.4f}s | Peak RAM: {ram_gb:.2f} GB | Peak VRAM: {vram_gb:.2f} GB")

    if any(v.class_type == "SaveImage" for v in prompt.values()):
        save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
        assert outputs[save_image_node_id]["images"][0]["abs_path"] is not None
    elif any(v.class_type == "SaveAudio" for v in prompt.values()):
        save_audio_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAudio")
        assert outputs[save_audio_node_id]["audio"][0]["filename"] is not None
    elif any(v.class_type == "SaveAnimatedWEBP" for v in prompt.values()):
        save_video_node_id = next(key for key in prompt if prompt[key].class_type == "SaveAnimatedWEBP")
        assert outputs[save_video_node_id]["images"][0]["filename"] is not None
    elif any(v.class_type == "PreviewString" for v in prompt.values()):
        preview_string_node_id = next(key for key in prompt if prompt[key].class_type == "PreviewString")
        output_str = outputs[preview_string_node_id]["string"][0]
        assert output_str is not None
        assert len(output_str) > 0
    else:
        assert len(outputs) > 0
        logger.warning(f"test {workflow_name} did not have a node that could be checked for output")