mirror of https://github.com/comfyanonymous/ComfyUI.git

Inference tests

parent 4028c1663b
commit 85772d450d
@@ -1019,7 +1019,7 @@ class LTXV(BaseModel):
 
 
 class HunyuanVideo(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, image_model=None):
         super().__init__(model_config, model_type, device=device, unet_model=HunyuanVideoModel)
 
     def encode_adm(self, **kwargs):
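Note: the new image_model keyword defaults to None, so existing call sites keep working; this hunk does not show where the argument is consumed. Purely illustrative construction call, assuming a suitable model_config instance is in scope:

    # Illustrative only; image_model semantics are not visible in this hunk.
    model = HunyuanVideo(model_config, model_type=ModelType.FLOW, device=None, image_model=None)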
@@ -228,7 +228,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
                 io.Image.Input("image"),
                 io.Combo.Input("upscale_method", options=cls.upscale_methods),
                 io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
-                io.Int.Input("resolution_steps", default=1, min=1, max=256),
+                io.Int.Input("resolution_steps", default=1, min=1, max=256, optional=True),
             ],
             outputs=[
                 io.Image.Output(),
@@ -236,7 +236,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
+    def execute(cls, image, upscale_method, megapixels, resolution_steps=1) -> io.NodeOutput:
         samples = image.movedim(-1, 1)
         total = megapixels * 1024 * 1024
 
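These two hunks go together: marking resolution_steps as optional=True means a prompt may omit the input, and the matching default of 1 in execute keeps such prompts working. The hunk is truncated after the pixel-budget line, so the following is only a sketch of the likely arithmetic, assuming resolution_steps snaps each output dimension to a multiple of that step:

    import math

    def target_size(height, width, megapixels, resolution_steps=1):
        # Sketch only: scale uniformly toward the requested pixel budget, then
        # snap each dimension to a multiple of resolution_steps (assumed
        # behaviour; the hunk ends before the node's actual rounding logic).
        total = megapixels * 1024 * 1024
        scale_by = math.sqrt(total / (width * height))
        snap = lambda v: max(resolution_steps, round(v * scale_by / resolution_steps) * resolution_steps)
        return snap(height), snap(width)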
@@ -143,6 +143,7 @@ dev = [
     "coverage",
     "pylint",
     "astroid",
+    "nvidia-ml-py",
 ]
 
 cpu = [
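nvidia-ml-py is NVIDIA's Python binding for NVML; installing it provides the pynvml module imported by the new test code below. A minimal standalone query looks like this (assumes an NVIDIA driver is present; otherwise nvmlInit raises):

    import pynvml  # provided by the nvidia-ml-py package

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # first GPU
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)   # total/used/free in bytes
    print(f"GPU0 VRAM used: {mem.used / (1024**3):.2f} GB of {mem.total / (1024**3):.2f} GB")
    pynvml.nvmlShutdown()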
@@ -4,6 +4,13 @@ import logging
 import time
 from importlib.abc import Traversable
 from typing import Any, AsyncGenerator
+import threading
+import psutil
+
+try:
+    import pynvml
+except ImportError:
+    pynvml = None
 
 import pytest
 
@@ -21,6 +28,84 @@ from comfy.cli_args_types import PerformanceFeature
 logger = logging.getLogger(__name__)
 
 
+class ResourceMonitor:
+    def __init__(self, interval: float = 0.1):
+        self.interval = interval
+        self.peak_cpu_ram = 0
+        self.peak_gpu_vram = 0
+        self._stop_event = threading.Event()
+        self._thread = None
+        self._pynvml_available = False
+        self._gpu_handles = []
+
+    def _monitor(self):
+        current_process = psutil.Process()
+        while not self._stop_event.is_set():
+            # Monitor CPU RAM (RSS) for process tree
+            try:
+                children = current_process.children(recursive=True)
+                processes = [current_process] + children
+                pids = {p.pid for p in processes}
+
+                total_rss = 0
+                for p in processes:
+                    try:
+                        total_rss += p.memory_info().rss
+                    except (psutil.NoSuchProcess, psutil.AccessDenied):
+                        pass
+                self.peak_cpu_ram = max(self.peak_cpu_ram, total_rss)
+
+                # Monitor GPU VRAM if available
+                if self._pynvml_available and self._gpu_handles:
+                    total_vram = 0
+                    try:
+                        # Iterate over all detected GPUs
+                        for handle in self._gpu_handles:
+                            # Get all processes running on the GPU
+                            try:
+                                compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+                                graphics_procs = pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)
+
+                                # Filter for our process tree
+                                for p in compute_procs + graphics_procs:
+                                    if p.pid in pids:
+                                        total_vram += p.usedGpuMemory
+                            except Exception:
+                                pass  # Skip errors for specific GPU queries
+
+                        self.peak_gpu_vram = max(self.peak_gpu_vram, total_vram)
+                    except Exception:
+                        pass
+            except Exception:
+                pass
+
+            time.sleep(self.interval)
+
+    def __enter__(self):
+        if pynvml:
+            try:
+                pynvml.nvmlInit()
+                device_count = pynvml.nvmlDeviceGetCount()
+                self._gpu_handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(device_count)]
+                self._pynvml_available = True
+            except Exception as e:
+                logger.warning(f"Failed to initialize pynvml for VRAM monitoring: {e}")
+
+        self._thread = threading.Thread(target=self._monitor, daemon=True)
+        self._thread.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_event.set()
+        if self._thread:
+            self._thread.join()
+        if self._pynvml_available:
+            try:
+                pynvml.nvmlShutdown()
+            except Exception:
+                pass
+
+
 def _generate_config_params():
     attn_keys = [
         "use_pytorch_cross_attention",
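ResourceMonitor is a context manager: __enter__ initializes NVML (if pynvml imported) and starts a daemon sampling thread; __exit__ stops it and shuts NVML down. A minimal standalone usage sketch, where run_workload() is a placeholder for any workload:

    # Standalone usage sketch (not part of the commit); run_workload() is hypothetical.
    with ResourceMonitor(interval=0.05) as monitor:
        run_workload()

    print(f"peak RAM:  {monitor.peak_cpu_ram / (1024**3):.2f} GB")
    print(f"peak VRAM: {monitor.peak_gpu_vram / (1024**3):.2f} GB")

Because the peaks come from periodic sampling, spikes shorter than the interval can be missed; a smaller interval trades monitoring overhead for resolution.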
@@ -94,13 +179,18 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     outputs = {}
 
     start_time = time.time()
+    monitor = ResourceMonitor()
     try:
-        outputs = await client.queue_prompt(prompt)
+        with monitor:
+            outputs = await client.queue_prompt(prompt)
     except TorchAudioNotFoundError:
         pytest.skip("requires torchaudio")
     finally:
         end_time = time.time()
-        logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")
+        duration = end_time - start_time
+        ram_gb = monitor.peak_cpu_ram / (1024**3)
+        vram_gb = monitor.peak_gpu_vram / (1024**3)
+        logger.info(f"Test {workflow_name} with client {client} took {duration:.4f}s | Peak RAM: {ram_gb:.2f} GB | Peak VRAM: {vram_gb:.2f} GB")
 
     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")
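The commit only logs the peaks alongside the duration. If a memory budget were wanted, the same values could feed an assertion inside the test; a sketch with a hypothetical limit, not part of the commit:

    # Optional budget check (hypothetical limit, illustrative only).
    MAX_PEAK_VRAM_GB = 24.0
    assert monitor.peak_gpu_vram / (1024**3) <= MAX_PEAK_VRAM_GB, (
        f"workflow exceeded VRAM budget: {monitor.peak_gpu_vram / (1024**3):.2f} GB"
    )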