Inference tests

doctorpangloss 2025-12-26 16:28:54 -08:00
parent 4028c1663b
commit 85772d450d
4 changed files with 96 additions and 5 deletions

View File

@@ -1019,7 +1019,7 @@ class LTXV(BaseModel):
 class HunyuanVideo(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, image_model=None):
         super().__init__(model_config, model_type, device=device, unet_model=HunyuanVideoModel)

     def encode_adm(self, **kwargs):
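The only functional change is the new image_model=None keyword, so existing call sites keep working. A minimal sketch of the before/after behavior (the call site and the image_encoder value are hypothetical, not from this commit):

    # Unchanged instantiation still works; image_model defaults to None.
    model = HunyuanVideo(model_config, model_type=ModelType.FLOW, device=device)
    # Hypothetical: supply an image model, e.g. for image-to-video conditioning.
    model_i2v = HunyuanVideo(model_config, model_type=ModelType.FLOW,
                             device=device, image_model=image_encoder)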

View File

@@ -228,7 +228,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
                 io.Image.Input("image"),
                 io.Combo.Input("upscale_method", options=cls.upscale_methods),
                 io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
-                io.Int.Input("resolution_steps", default=1, min=1, max=256),
+                io.Int.Input("resolution_steps", default=1, min=1, max=256, optional=True),
             ],
             outputs=[
                 io.Image.Output(),
@@ -236,7 +236,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
         )

     @classmethod
-    def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
+    def execute(cls, image, upscale_method, megapixels, resolution_steps=1) -> io.NodeOutput:
         samples = image.movedim(-1, 1)
         total = megapixels * 1024 * 1024
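The hunk cuts off before the resize itself; below is a hedged sketch of what a resolution_steps-aware rescale could look like, snapping each side to a multiple of resolution_steps. The snapping rule is an assumption; only total = megapixels * 1024 * 1024 comes from the diff.

    import math

    def scale_to_total_pixels(width, height, megapixels=1.0, resolution_steps=1):
        total = megapixels * 1024 * 1024
        scale_by = math.sqrt(total / (width * height))
        # Assumed: snap each side to the nearest multiple of resolution_steps,
        # never dropping below one step.
        new_w = max(resolution_steps, round(width * scale_by / resolution_steps) * resolution_steps)
        new_h = max(resolution_steps, round(height * scale_by / resolution_steps) * resolution_steps)
        return new_w, new_h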

View File

@@ -143,6 +143,7 @@ dev = [
     "coverage",
     "pylint",
     "astroid",
+    "nvidia-ml-py",
 ]
 cpu = [
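Note: nvidia-ml-py is NVIDIA's official NVML binding and is the package that provides the pynvml module, so adding it to the dev extra is what lets the guarded import pynvml in the test changes below succeed.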

View File

@@ -4,6 +4,13 @@ import logging
 import time
 from importlib.abc import Traversable
 from typing import Any, AsyncGenerator
+import threading
+import psutil
+
+try:
+    import pynvml
+except ImportError:
+    pynvml = None

 import pytest
@@ -21,6 +28,84 @@ from comfy.cli_args_types import PerformanceFeature
 logger = logging.getLogger(__name__)

+class ResourceMonitor:
+    """Samples peak process-tree RAM and GPU VRAM on a background thread."""
+
+    def __init__(self, interval: float = 0.1):
+        self.interval = interval
+        self.peak_cpu_ram = 0
+        self.peak_gpu_vram = 0
+        self._stop_event = threading.Event()
+        self._thread = None
+        self._pynvml_available = False
+        self._gpu_handles = []
+
+    def _monitor(self):
+        current_process = psutil.Process()
+        while not self._stop_event.is_set():
+            # Monitor CPU RAM (RSS) for the whole process tree
+            try:
+                children = current_process.children(recursive=True)
+                processes = [current_process] + children
+                pids = {p.pid for p in processes}
+                total_rss = 0
+                for p in processes:
+                    try:
+                        total_rss += p.memory_info().rss
+                    except (psutil.NoSuchProcess, psutil.AccessDenied):
+                        pass
+                self.peak_cpu_ram = max(self.peak_cpu_ram, total_rss)
+
+                # Monitor GPU VRAM if available
+                if self._pynvml_available and self._gpu_handles:
+                    total_vram = 0
+                    try:
+                        # Iterate over all detected GPUs
+                        for handle in self._gpu_handles:
+                            # Get all processes running on the GPU
+                            try:
+                                compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+                                graphics_procs = pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)
+                                # Filter for our process tree; usedGpuMemory can be
+                                # None when the driver withholds per-process accounting
+                                for p in compute_procs + graphics_procs:
+                                    if p.pid in pids and p.usedGpuMemory is not None:
+                                        total_vram += p.usedGpuMemory
+                            except Exception:
+                                pass  # Skip errors for specific GPU queries
+                        self.peak_gpu_vram = max(self.peak_gpu_vram, total_vram)
+                    except Exception:
+                        pass
+            except Exception:
+                pass
+            time.sleep(self.interval)
+
+    def __enter__(self):
+        if pynvml:
+            try:
+                pynvml.nvmlInit()
+                device_count = pynvml.nvmlDeviceGetCount()
+                self._gpu_handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(device_count)]
+                self._pynvml_available = True
+            except Exception as e:
+                logger.warning(f"Failed to initialize pynvml for VRAM monitoring: {e}")
+        self._thread = threading.Thread(target=self._monitor, daemon=True)
+        self._thread.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_event.set()
+        if self._thread:
+            self._thread.join()
+        if self._pynvml_available:
+            try:
+                pynvml.nvmlShutdown()
+            except Exception:
+                pass
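For reference, a minimal standalone use of the monitor added above; only run_workload() is a hypothetical stand-in (the test itself wraps client.queue_prompt, as shown in the next hunk):

    monitor = ResourceMonitor(interval=0.1)
    with monitor:
        run_workload()  # stand-in for the real workload
    print(f"peak RAM: {monitor.peak_cpu_ram / (1024 ** 3):.2f} GB, "
          f"peak VRAM: {monitor.peak_gpu_vram / (1024 ** 3):.2f} GB")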
 def _generate_config_params():
     attn_keys = [
         "use_pytorch_cross_attention",
@@ -94,13 +179,18 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     outputs = {}
     start_time = time.time()
+    monitor = ResourceMonitor()
     try:
-        outputs = await client.queue_prompt(prompt)
+        with monitor:
+            outputs = await client.queue_prompt(prompt)
     except TorchAudioNotFoundError:
         pytest.skip("requires torchaudio")
     finally:
         end_time = time.time()
-        logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")
+        duration = end_time - start_time
+        ram_gb = monitor.peak_cpu_ram / (1024**3)
+        vram_gb = monitor.peak_gpu_vram / (1024**3)
+        logger.info(f"Test {workflow_name} with client {client} took {duration:.4f}s | Peak RAM: {ram_gb:.2f} GB | Peak VRAM: {vram_gb:.2f} GB")
     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")