Commit: 85772d450d ("Inference tests")
Parent: 4028c1663b
Repository: mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-07 21:00:49 +08:00)
@@ -1019,7 +1019,7 @@ class LTXV(BaseModel):
 
 
 class HunyuanVideo(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, image_model=None):
         super().__init__(model_config, model_type, device=device, unet_model=HunyuanVideoModel)
 
     def encode_adm(self, **kwargs):
@@ -228,7 +228,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
                 io.Image.Input("image"),
                 io.Combo.Input("upscale_method", options=cls.upscale_methods),
                 io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
-                io.Int.Input("resolution_steps", default=1, min=1, max=256),
+                io.Int.Input("resolution_steps", default=1, min=1, max=256, optional=True),
             ],
             outputs=[
                 io.Image.Output(),
@@ -236,7 +236,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
+    def execute(cls, image, upscale_method, megapixels, resolution_steps=1) -> io.NodeOutput:
        samples = image.movedim(-1, 1)
        total = megapixels * 1024 * 1024
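Aside: the hunk above shows the new optional input but not how `resolution_steps` is consumed inside `execute`. As a standalone sketch only (the helper name `scale_to_total_pixels` and the exact rounding rule are assumptions, not code from this commit), snapping dimensions to a total-pixel budget could look like:

import math

def scale_to_total_pixels(width: int, height: int, megapixels: float, resolution_steps: int = 1) -> tuple[int, int]:
    # Hypothetical helper: keep aspect ratio, target `megapixels` total pixels,
    # and round each side to the nearest multiple of `resolution_steps`.
    total = megapixels * 1024 * 1024
    scale = math.sqrt(total / (width * height))
    new_w = max(resolution_steps, round(width * scale / resolution_steps) * resolution_steps)
    new_h = max(resolution_steps, round(height * scale / resolution_steps) * resolution_steps)
    return new_w, new_h

print(scale_to_total_pixels(1920, 1080, 1.0, 64))  # (1344, 768), roughly 1 MP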
@@ -143,6 +143,7 @@ dev = [
     "coverage",
     "pylint",
     "astroid",
+    "nvidia-ml-py",
 ]
 
 cpu = [
@@ -4,6 +4,13 @@ import logging
 import time
 from importlib.abc import Traversable
 from typing import Any, AsyncGenerator
+import threading
+import psutil
+
+try:
+    import pynvml
+except ImportError:
+    pynvml = None
 
 import pytest
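The guarded import pairs with the new `nvidia-ml-py` dev dependency (the package that provides the `pynvml` module): on machines without it, VRAM monitoring silently degrades instead of breaking test collection. A hedged sketch of how such a guard is typically consumed by tests that hard-require NVML; the `requires_nvml` marker is hypothetical, not part of this commit:

import pytest

# Hypothetical skip marker built on the guard above; not in this commit.
requires_nvml = pytest.mark.skipif(pynvml is None, reason="nvidia-ml-py not installed")

@requires_nvml
def test_nvml_sees_a_gpu():
    pynvml.nvmlInit()
    try:
        assert pynvml.nvmlDeviceGetCount() >= 1
    finally:
        pynvml.nvmlShutdown()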
@@ -21,6 +28,84 @@ from comfy.cli_args_types import PerformanceFeature
 logger = logging.getLogger(__name__)
 
 
+class ResourceMonitor:
+    def __init__(self, interval: float = 0.1):
+        self.interval = interval
+        self.peak_cpu_ram = 0
+        self.peak_gpu_vram = 0
+        self._stop_event = threading.Event()
+        self._thread = None
+        self._pynvml_available = False
+        self._gpu_handles = []
+
+    def _monitor(self):
+        current_process = psutil.Process()
+        while not self._stop_event.is_set():
+            # Monitor CPU RAM (RSS) for process tree
+            try:
+                children = current_process.children(recursive=True)
+                processes = [current_process] + children
+                pids = {p.pid for p in processes}
+
+                total_rss = 0
+                for p in processes:
+                    try:
+                        total_rss += p.memory_info().rss
+                    except (psutil.NoSuchProcess, psutil.AccessDenied):
+                        pass
+                self.peak_cpu_ram = max(self.peak_cpu_ram, total_rss)
+
+                # Monitor GPU VRAM if available
+                if self._pynvml_available and self._gpu_handles:
+                    total_vram = 0
+                    try:
+                        # Iterate over all detected GPUs
+                        for handle in self._gpu_handles:
+                            # Get all processes running on the GPU
+                            try:
+                                compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+                                graphics_procs = pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)
+
+                                # Filter for our process tree
+                                for p in compute_procs + graphics_procs:
+                                    if p.pid in pids:
+                                        total_vram += p.usedGpuMemory
+                            except Exception:
+                                pass  # Skip errors for specific GPU queries
+
+                        self.peak_gpu_vram = max(self.peak_gpu_vram, total_vram)
+                    except Exception:
+                        pass
+            except Exception:
+                pass
+
+            time.sleep(self.interval)
+
+    def __enter__(self):
+        if pynvml:
+            try:
+                pynvml.nvmlInit()
+                device_count = pynvml.nvmlDeviceGetCount()
+                self._gpu_handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(device_count)]
+                self._pynvml_available = True
+            except Exception as e:
+                logger.warning(f"Failed to initialize pynvml for VRAM monitoring: {e}")
+
+        self._thread = threading.Thread(target=self._monitor, daemon=True)
+        self._thread.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_event.set()
+        if self._thread:
+            self._thread.join()
+        if self._pynvml_available:
+            try:
+                pynvml.nvmlShutdown()
+            except Exception:
+                pass
+
+
 def _generate_config_params():
     attn_keys = [
         "use_pytorch_cross_attention",
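`ResourceMonitor` is an ordinary context manager: on entry it starts a daemon thread that samples RSS for the whole process tree every `interval` seconds, plus per-process VRAM via NVML when available, and the peaks remain readable after exit because they are plain attributes. A minimal usage sketch, assuming the class above is in scope; the 512 MB allocation and the timings are arbitrary:

import time

with ResourceMonitor(interval=0.05) as monitor:
    blob = bytearray(512 * 1024 * 1024)  # hold ~512 MB long enough for the sampler to see it
    time.sleep(0.5)
    del blob

print(f"Peak RAM:  {monitor.peak_cpu_ram / (1024**3):.2f} GB")
print(f"Peak VRAM: {monitor.peak_gpu_vram / (1024**3):.2f} GB")  # stays 0.00 without a GPU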
@@ -94,13 +179,18 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     outputs = {}
 
     start_time = time.time()
+    monitor = ResourceMonitor()
     try:
-        outputs = await client.queue_prompt(prompt)
+        with monitor:
+            outputs = await client.queue_prompt(prompt)
     except TorchAudioNotFoundError:
         pytest.skip("requires torchaudio")
     finally:
         end_time = time.time()
-        logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")
+        duration = end_time - start_time
+        ram_gb = monitor.peak_cpu_ram / (1024**3)
+        vram_gb = monitor.peak_gpu_vram / (1024**3)
+        logger.info(f"Test {workflow_name} with client {client} took {duration:.4f}s | Peak RAM: {ram_gb:.2f} GB | Peak VRAM: {vram_gb:.2f} GB")
 
     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
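As committed, the peaks are only logged, never asserted on. If memory budgets were wanted later, a natural follow-up placed after the timing block (not inside `finally`, where a raise would mask an in-flight exception) could look like this; the 32 GB / 24 GB limits are invented placeholders, not values from this commit:

# Hypothetical guardrails on the peaks gathered above; budgets are placeholders.
MAX_RAM_GB = 32.0
MAX_VRAM_GB = 24.0
assert ram_gb <= MAX_RAM_GB, f"{workflow_name} peaked at {ram_gb:.2f} GB RAM"
if has_gpu:
    assert vram_gb <= MAX_VRAM_GB, f"{workflow_name} peaked at {vram_gb:.2f} GB VRAM"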