diff --git a/comfy/model_base.py b/comfy/model_base.py
index 0b132ee61..42b50669a 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1019,7 +1019,7 @@ class LTXV(BaseModel):
 
 
 class HunyuanVideo(BaseModel):
-    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, image_model=None):
         super().__init__(model_config, model_type, device=device, unet_model=HunyuanVideoModel)
 
     def encode_adm(self, **kwargs):
diff --git a/comfy_extras/nodes/nodes_post_processing.py b/comfy_extras/nodes/nodes_post_processing.py
index 2d6208db1..3a962ad4d 100644
--- a/comfy_extras/nodes/nodes_post_processing.py
+++ b/comfy_extras/nodes/nodes_post_processing.py
@@ -228,7 +228,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
                 io.Image.Input("image"),
                 io.Combo.Input("upscale_method", options=cls.upscale_methods),
                 io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01),
-                io.Int.Input("resolution_steps", default=1, min=1, max=256),
+                io.Int.Input("resolution_steps", default=1, min=1, max=256, optional=True),
             ],
             outputs=[
                 io.Image.Output(),
@@ -236,7 +236,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, image, upscale_method, megapixels, resolution_steps) -> io.NodeOutput:
+    def execute(cls, image, upscale_method, megapixels, resolution_steps=1) -> io.NodeOutput:
         samples = image.movedim(-1, 1)
         total = megapixels * 1024 * 1024
 
diff --git a/pyproject.toml b/pyproject.toml
index 01f6da028..42475cd16 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -143,6 +143,7 @@ dev = [
     "coverage",
     "pylint",
     "astroid",
+    "nvidia-ml-py",
 ]
 
 cpu = [
diff --git a/tests/inference/test_workflows.py b/tests/inference/test_workflows.py
index b62a8892c..9e3f97024 100644
--- a/tests/inference/test_workflows.py
+++ b/tests/inference/test_workflows.py
@@ -4,6 +4,13 @@
 import logging
 import time
 from importlib.abc import Traversable
 from typing import Any, AsyncGenerator
+import threading
+import psutil
+
+try:
+    import pynvml
+except ImportError:
+    pynvml = None
 
 import pytest
@@ -21,6 +28,84 @@
 from comfy.cli_args_types import PerformanceFeature
 
 logger = logging.getLogger(__name__)
 
 
+class ResourceMonitor:
+    def __init__(self, interval: float = 0.1):
+        self.interval = interval
+        self.peak_cpu_ram = 0
+        self.peak_gpu_vram = 0
+        self._stop_event = threading.Event()
+        self._thread = None
+        self._pynvml_available = False
+        self._gpu_handles = []
+
+    def _monitor(self):
+        current_process = psutil.Process()
+        while not self._stop_event.is_set():
+            # Monitor CPU RAM (RSS) for process tree
+            try:
+                children = current_process.children(recursive=True)
+                processes = [current_process] + children
+                pids = {p.pid for p in processes}
+
+                total_rss = 0
+                for p in processes:
+                    try:
+                        total_rss += p.memory_info().rss
+                    except (psutil.NoSuchProcess, psutil.AccessDenied):
+                        pass
+                self.peak_cpu_ram = max(self.peak_cpu_ram, total_rss)
+
+                # Monitor GPU VRAM if available
+                if self._pynvml_available and self._gpu_handles:
+                    total_vram = 0
+                    try:
+                        # Iterate over all detected GPUs
+                        for handle in self._gpu_handles:
+                            # Get all processes running on the GPU
+                            try:
+                                compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+                                graphics_procs = pynvml.nvmlDeviceGetGraphicsRunningProcesses(handle)
+
+                                # Filter for our process tree; usedGpuMemory may be None when NVML cannot account per-process memory
+                                for p in compute_procs + graphics_procs:
+                                    if p.pid in pids and p.usedGpuMemory is not None:
+                                        total_vram += p.usedGpuMemory
+                            except Exception:
+                                pass  # Skip errors for specific GPU queries
+
+                        self.peak_gpu_vram = max(self.peak_gpu_vram, total_vram)
+                    except Exception:
+                        pass
+            except Exception:
+                pass
+
+            time.sleep(self.interval)
+
+    def __enter__(self):
+        if pynvml:
+            try:
+                pynvml.nvmlInit()
+                device_count = pynvml.nvmlDeviceGetCount()
+                self._gpu_handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(device_count)]
+                self._pynvml_available = True
+            except Exception as e:
+                logger.warning(f"Failed to initialize pynvml for VRAM monitoring: {e}")
+
+        self._thread = threading.Thread(target=self._monitor, daemon=True)
+        self._thread.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self._stop_event.set()
+        if self._thread:
+            self._thread.join()
+        if self._pynvml_available:
+            try:
+                pynvml.nvmlShutdown()
+            except Exception:
+                pass
+
+
 def _generate_config_params():
     attn_keys = [
         "use_pytorch_cross_attention",
@@ -94,13 +179,18 @@ async def test_workflow(workflow_name: str, workflow_file: Traversable, has_gpu:
     outputs = {}
 
     start_time = time.time()
+    monitor = ResourceMonitor()
     try:
-        outputs = await client.queue_prompt(prompt)
+        with monitor:
+            outputs = await client.queue_prompt(prompt)
     except TorchAudioNotFoundError:
         pytest.skip("requires torchaudio")
     finally:
         end_time = time.time()
-        logger.info(f"Test {workflow_name} with client {client} took {end_time - start_time:.4f}s")
+        duration = end_time - start_time
+        ram_gb = monitor.peak_cpu_ram / (1024**3)
+        vram_gb = monitor.peak_gpu_vram / (1024**3)
+        logger.info(f"Test {workflow_name} with client {client} took {duration:.4f}s | Peak RAM: {ram_gb:.2f} GB | Peak VRAM: {vram_gb:.2f} GB")
 
     if any(v.class_type == "SaveImage" for v in prompt.values()):
         save_image_node_id = next(key for key in prompt if prompt[key].class_type == "SaveImage")