diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index df2d8e827..3dcac3eef 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -699,12 +699,12 @@ class ModelPatcher: offloaded = [] offload_buffer = 0 loading.sort(reverse=True) - for x in loading: + for i, x in enumerate(loading): module_offload_mem, module_mem, n, m, params = x lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) + potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]])) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -876,14 +876,18 @@ class ModelPatcher: patch_counter = 0 unload_list = self._load_list() unload_list.sort() + offload_buffer = self.model.model_offload_buffer_memory + if len(unload_list) > 0: + NS = comfy.model_management.NUM_STREAMS + offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS for unload in unload_list: if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) + potential_offload = module_offload_mem + sum(offload_weight_factor) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -935,6 +939,8 @@ class ModelPatcher: m.comfy_patched_weights = False memory_freed += module_mem offload_buffer = max(offload_buffer, potential_offload) + offload_weight_factor.append(module_mem) + offload_weight_factor.pop(0) logging.debug("freed {}".format(n)) for param in params: diff --git a/comfy/sd.py b/comfy/sd.py index f9e5efab5..03bdb33d5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -193,6 +193,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: @@ -240,6 +241,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] if return_dict: @@ -469,7 +471,7 @@ class VAE: decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -481,8 +483,10 @@ class VAE: self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + #This is likely to significantly over-estimate with single image or low frame counts as the + #implementation is able to completely skip caching. Rework if used as an image only VAE + self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.unpatcher3d.wavelets" in sd: self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0fc9ab3db..503a51843 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled self.return_attention_masks = return_attention_masks + self.execution_device = None if layer == "hidden": assert layer_idx is not None @@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) + self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: @@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = self.options_default[0] self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] + self.execution_device = None def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) @@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): - device = self.transformer.get_input_embeddings().weight.device + if self.execution_device is None: + device = self.transformer.get_input_embeddings().weight.device + else: + device = self.execution_device + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 6d1bea599..5a75a3aae 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -319,9 +319,10 @@ class AudioSaveHelper: for key, value in metadata.items(): output_container.metadata[key] = value + layout = "mono" if waveform.shape[0] == 1 else "stereo" # Set up the output stream with appropriate properties if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate) + out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) if quality == "64k": out_stream.bit_rate = 64000 elif quality == "96k": @@ -333,7 +334,7 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) if quality == "V0": # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool out_stream.codec_context.qscale = 1 @@ -342,12 +343,12 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 else: # format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate) + out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) frame = av.AudioFrame.from_ndarray( waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format="flt", - layout="mono" if waveform.shape[0] == 1 else "stereo", + layout=layout, ) frame.sample_rate = sample_rate frame.pts = 0 diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 80292fb3c..4cc22abfb 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -47,6 +47,7 @@ from .validation_utils import ( validate_string, validate_video_dimensions, validate_video_duration, + validate_video_frame_count, ) __all__ = [ @@ -94,6 +95,7 @@ __all__ = [ "validate_string", "validate_video_dimensions", "validate_video_duration", + "validate_video_frame_count", # Misc functions "get_fs_object_size", ] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 328fe5227..491e6b6a8 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -2,8 +2,8 @@ import asyncio import contextlib import os import time +from collections.abc import Callable from io import BytesIO -from typing import Callable, Optional, Union from comfy.cli_args import args from comfy.model_management import processing_interrupted @@ -35,12 +35,12 @@ def default_base_url() -> str: async def sleep_with_interrupt( seconds: float, - node_cls: Optional[type[IO.ComfyNode]], - label: Optional[str] = None, - start_ts: Optional[float] = None, - estimated_total: Optional[int] = None, + node_cls: type[IO.ComfyNode] | None, + label: str | None = None, + start_ts: float | None = None, + estimated_total: int | None = None, *, - display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, + display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None, ): """ Sleep in 1s slices while: @@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: +def get_fs_object_size(path_or_object: str | BytesIO) -> int: if isinstance(path_or_object, str): return os.path.getsize(path_or_object) return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf01d7d36..bf37cba5f 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -4,10 +4,11 @@ import json import logging import time import uuid +from collections.abc import Callable, Iterable from dataclasses import dataclass from enum import Enum from io import BytesIO -from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union +from typing import Any, Literal, TypeVar from urllib.parse import urljoin, urlparse import aiohttp @@ -37,8 +38,8 @@ class ApiEndpoint: path: str, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", *, - query_params: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, + query_params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, ): self.path = path self.method = method @@ -52,18 +53,18 @@ class _RequestConfig: endpoint: ApiEndpoint timeout: float content_type: str - data: Optional[dict[str, Any]] - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] - multipart_parser: Optional[Callable] + data: dict[str, Any] | None + files: dict[str, Any] | list[tuple[str, Any]] | None + multipart_parser: Callable | None max_retries: int retry_delay: float retry_backoff: float wait_label: str = "Waiting" monitor_progress: bool = True - estimated_total: Optional[int] = None - final_label_on_success: Optional[str] = "Completed" - progress_origin_ts: Optional[float] = None - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None + estimated_total: int | None = None + final_label_on_success: str | None = "Completed" + progress_origin_ts: float | None = None + price_extractor: Callable[[dict[str, Any]], float | None] | None = None @dataclass @@ -71,10 +72,10 @@ class _PollUIState: started: float status_label: str = "Queued" is_queued: bool = True - price: Optional[float] = None - estimated_duration: Optional[int] = None + price: float | None = None + estimated_duration: int | None = None base_processing_elapsed: float = 0.0 # sum of completed active intervals - active_since: Optional[float] = None # start time of current active interval (None if queued) + active_since: float | None = None # start time of current active interval (None if queued) _RETRY_STATUS = {408, 429, 500, 502, 503, 504} @@ -87,20 +88,20 @@ async def sync_op( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - response_model: Type[M], - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - data: Optional[BaseModel] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + response_model: type[M], + price_extractor: Callable[[M | Any], float | None] | None = None, + data: BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + estimated_duration: int | None = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, ) -> M: raw = await sync_op_raw( @@ -131,22 +132,22 @@ async def poll_op( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - response_model: Type[M], - status_extractor: Callable[[M], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[M], Optional[int]]] = None, - price_extractor: Optional[Callable[[M], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[BaseModel] = None, + response_model: type[M], + status_extractor: Callable[[M | Any], str | int | None], + progress_extractor: Callable[[M | Any], int | None] | None = None, + price_extractor: Callable[[M | Any], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> M: raw = await poll_op_raw( @@ -178,22 +179,22 @@ async def sync_op_raw( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + data: dict[str, Any] | BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, content_type: str = "application/json", timeout: float = 3600.0, - multipart_parser: Optional[Callable] = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_duration: Optional[int] = None, + estimated_duration: int | None = None, as_binary: bool = False, - final_label_on_success: Optional[str] = "Completed", - progress_origin_ts: Optional[float] = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, monitor_progress: bool = True, -) -> Union[dict[str, Any], bytes]: +) -> dict[str, Any] | bytes: """ Make a single network request. - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). @@ -229,21 +230,21 @@ async def poll_op_raw( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], - progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, - price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - completed_statuses: Optional[list[Union[str, int]]] = None, - failed_statuses: Optional[list[Union[str, int]]] = None, - queued_statuses: Optional[list[Union[str, int]]] = None, - data: Optional[Union[dict[str, Any], BaseModel]] = None, + status_extractor: Callable[[dict[str, Any]], str | int | None], + progress_extractor: Callable[[dict[str, Any]], int | None] | None = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, retry_backoff_per_poll: float = 2.0, - estimated_duration: Optional[int] = None, - cancel_endpoint: Optional[ApiEndpoint] = None, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, ) -> dict[str, Any]: """ @@ -261,7 +262,7 @@ async def poll_op_raw( consumed_attempts = 0 # counts only non-queued polls progress_bar = utils.ProgressBar(100) if progress_extractor else None - last_progress: Optional[int] = None + last_progress: int | None = None state = _PollUIState(started=started, estimated_duration=estimated_duration) stop_ticker = asyncio.Event() @@ -420,10 +421,10 @@ async def poll_op_raw( def _display_text( node_cls: type[IO.ComfyNode], - text: Optional[str], + text: str | None, *, - status: Optional[Union[str, int]] = None, - price: Optional[float] = None, + status: str | int | None = None, + price: float | None = None, ) -> None: display_lines: list[str] = [] if status: @@ -440,13 +441,13 @@ def _display_text( def _display_time_progress( node_cls: type[IO.ComfyNode], - status: Optional[Union[str, int]], + status: str | int | None, elapsed_seconds: int, - estimated_total: Optional[int] = None, + estimated_total: int | None = None, *, - price: Optional[float] = None, - is_queued: Optional[bool] = None, - processing_elapsed_seconds: Optional[int] = None, + price: float | None = None, + is_queued: bool | None = None, + processing_elapsed_seconds: int | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds @@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: raise ValueError("files tuple must be (filename, file[, content_type])") -def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: +def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: params = dict(endpoint_params or {}) if method.upper() == "GET" and data: for k, v in data.items(): @@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str: def _snapshot_request_body_for_logging( content_type: str, method: str, - data: Optional[dict[str, Any]], - files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], -) -> Optional[Union[dict[str, Any], str]]: + data: dict[str, Any] | None, + files: dict[str, Any] | list[tuple[str, Any]] | None, +) -> dict[str, Any] | str | None: if method.upper() == "GET": return None if content_type == "multipart/form-data": @@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): attempt = 0 delay = cfg.retry_delay operation_succeeded: bool = False - final_elapsed_seconds: Optional[int] = None - extracted_price: Optional[float] = None + final_elapsed_seconds: int | None = None + extracted_price: float | None = None while True: attempt += 1 stop_event = asyncio.Event() - monitor_task: Optional[asyncio.Task] = None - sess: Optional[aiohttp.ClientSession] = None + monitor_task: asyncio.Task | None = None + sess: aiohttp.ClientSession | None = None operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) @@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): ) -def _validate_or_raise(response_model: Type[M], payload: Any) -> M: +def _validate_or_raise(response_model: type[M], payload: Any) -> M: try: return response_model.model_validate(payload) except Exception as e: @@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M: def _wrap_model_extractor( - response_model: Type[M], - extractor: Optional[Callable[[M], Any]], -) -> Optional[Callable[[dict[str, Any]], Any]]: + response_model: type[M], + extractor: Callable[[M], Any] | None, +) -> Callable[[dict[str, Any]], Any] | None: """Wrap a typed extractor so it can be used by the dict-based poller. Validates the dict into `response_model` before invoking `extractor`. Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating @@ -929,10 +930,10 @@ def _wrap_model_extractor( return _wrapped -def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: +def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]: if not values: return set() - out: set[Union[str, int]] = set() + out: set[str | int] = set() for v in values: nv = _normalize_status_value(v) if nv is not None: @@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio return out -def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: +def _normalize_status_value(val: str | int | None) -> str | int | None: if isinstance(val, str): return val.strip().lower() return val diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 971dc57de..c57457580 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -4,7 +4,6 @@ import math import mimetypes import uuid from io import BytesIO -from typing import Optional import av import numpy as np @@ -12,8 +11,7 @@ import torch from PIL import Image from comfy.utils import common_upscale -from comfy_api.latest import Input, InputImpl -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import Input, InputImpl, Types from ._helpers import mimetype_to_extension @@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: Optional[str] = None, + name: str | None = None, total_pixels: int = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: @@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co def video_to_base64_string( video: Input.Video, - container_format: VideoContainer = None, - codec: VideoCodec = None + container_format: Types.VideoContainer | None = None, + codec: Types.VideoCodec | None = None, ) -> str: """ Converts a video input to a base64 string. @@ -189,12 +187,11 @@ def video_to_base64_string( codec: Optional codec to use (defaults to video.codec if available) """ video_bytes_io = BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) + video.save_to( + video_bytes_io, + format=container_format or getattr(video, "container", Types.VideoContainer.MP4), + codec=codec or getattr(video, "codec", Types.VideoCodec.H264), + ) video_bytes_io.seek(0) return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 14207dc68..3e0d0352d 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -3,15 +3,15 @@ import contextlib import uuid from io import BytesIO from pathlib import Path -from typing import IO, Optional, Union +from typing import IO from urllib.parse import urljoin, urlparse import aiohttp import torch from aiohttp.client_exceptions import ClientError, ContentTypeError -from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO as COMFY_IO +from comfy_api.latest import InputImpl from . import request_logger from ._helpers import ( @@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504} async def download_url_to_bytesio( url: str, - dest: Optional[Union[BytesIO, IO[bytes], str, Path]], + dest: BytesIO | IO[bytes] | str | Path | None, *, - timeout: Optional[float] = None, + timeout: float | None = None, max_retries: int = 5, retry_delay: float = 1.0, retry_backoff: float = 2.0, @@ -71,10 +71,10 @@ async def download_url_to_bytesio( is_path_sink = isinstance(dest, (str, Path)) fhandle = None - session: Optional[aiohttp.ClientSession] = None - stop_evt: Optional[asyncio.Event] = None - monitor_task: Optional[asyncio.Task] = None - req_task: Optional[asyncio.Task] = None + session: aiohttp.ClientSession | None = None + stop_evt: asyncio.Event | None = None + monitor_task: asyncio.Task | None = None + req_task: asyncio.Task | None = None try: with contextlib.suppress(Exception): @@ -234,11 +234,11 @@ async def download_url_to_video_output( timeout: float = None, max_retries: int = 5, cls: type[COMFY_IO.ComfyNode] = None, -) -> VideoFromFile: +) -> InputImpl.VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) - return VideoFromFile(result) + return InputImpl.VideoFromFile(result) async def download_url_as_bytesio( diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py index ac52e2eab..e0cb4428d 100644 --- a/comfy_api_nodes/util/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import datetime import hashlib import json diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 0532bea9a..b8d33f4d1 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -4,15 +4,13 @@ import logging import time import uuid from io import BytesIO -from typing import Optional from urllib.parse import urlparse import aiohttp import torch from pydantic import BaseModel, Field -from comfy_api.latest import IO, Input -from comfy_api.util import VideoCodec, VideoContainer +from comfy_api.latest import IO, Input, Types from . import request_logger from ._helpers import is_processing_interrupted, sleep_with_interrupt @@ -32,7 +30,7 @@ from .conversions import ( class UploadRequest(BaseModel): file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( + content_type: str | None = Field( None, description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", ) @@ -56,7 +54,7 @@ async def upload_images_to_comfyapi( Uploads images to ComfyUI API and returns download URLs. To upload multiple images, stack them in the batch dimension first. """ - # if batch, try to upload each file if max_images is greater than 0 + # if batched, try to upload each file if max_images is greater than 0 download_urls: list[str] = [] is_batch = len(image.shape) > 3 batch_len = image.shape[0] if is_batch else 1 @@ -100,9 +98,9 @@ async def upload_video_to_comfyapi( cls: type[IO.ComfyNode], video: Input.Video, *, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, + container: Types.VideoContainer = Types.VideoContainer.MP4, + codec: Types.VideoCodec = Types.VideoCodec.H264, + max_duration: int | None = None, wait_label: str | None = "Uploading", ) -> str: """ @@ -220,7 +218,7 @@ async def upload_file( return monitor_task = asyncio.create_task(_monitor()) - sess: Optional[aiohttp.ClientSession] = None + sess: aiohttp.ClientSession | None = None try: try: request_logger.log_request_response( diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ec7006aed..f01edea96 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -1,9 +1,7 @@ import logging -from typing import Optional import torch -from comfy_api.input.video_types import VideoInput from comfy_api.latest import Input @@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: def validate_image_dimensions( image: torch.Tensor, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): height, width = get_image_dimensions(image) @@ -37,8 +35,8 @@ def validate_image_dimensions( def validate_image_aspect_ratio( image: torch.Tensor, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: @@ -54,8 +52,8 @@ def validate_image_aspect_ratio( def validate_images_aspect_ratio_closeness( first_image: torch.Tensor, second_image: torch.Tensor, - min_rel: float, # e.g. 0.8 - max_rel: float, # e.g. 1.25 + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness( def validate_aspect_ratio_string( aspect_ratio: str, - min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) - max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) *, strict: bool = False, # True -> (min, max); False -> [min, max] ) -> float: @@ -97,10 +95,10 @@ def validate_aspect_ratio_string( def validate_video_dimensions( video: Input.Video, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): try: width, height = video.get_dimensions() @@ -120,8 +118,8 @@ def validate_video_dimensions( def validate_video_duration( video: Input.Video, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ): try: duration = video.get_duration() @@ -136,6 +134,23 @@ def validate_video_duration( raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") +def validate_video_frame_count( + video: Input.Video, + min_frame_count: int | None = None, + max_frame_count: int | None = None, +): + try: + frame_count = video.get_frame_count() + except Exception as e: + logging.error("Error getting frame count of video: %s", e) + return + + if min_frame_count is not None and min_frame_count > frame_count: + raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}") + if max_frame_count is not None and frame_count > max_frame_count: + raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}") + + def get_number_of_images(images): if isinstance(images, torch.Tensor): return images.shape[0] if images.ndim >= 4 else 1 @@ -144,8 +159,8 @@ def get_number_of_images(images): def validate_audio_duration( audio: Input.Audio, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ) -> None: sr = int(audio["sample_rate"]) dur = int(audio["waveform"].shape[-1]) / sr @@ -177,7 +192,7 @@ def validate_string( ) -def validate_container_format_is_mp4(video: VideoInput) -> None: +def validate_container_format_is_mp4(video: Input.Video) -> None: """Validates video container format is MP4.""" container_format = video.get_container_format() if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: @@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float: def _assert_ratio_bounds( ar: float, *, - min_ratio: Optional[tuple[float, float]] = None, - max_ratio: Optional[tuple[float, float]] = None, + min_ratio: tuple[float, float] | None = None, + max_ratio: tuple[float, float] | None = None, strict: bool = True, ) -> None: """Validate a numeric aspect ratio against optional min/max ratio bounds.""" diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 2ed7e0b22..812301fb7 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -6,65 +6,80 @@ import torch import comfy.model_management import folder_paths import os -import io -import json -import random import hashlib import node_helpers import logging -from comfy.cli_args import args -from comfy.comfy_types import FileLocator +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI -class EmptyLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyLatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentAudio", + display_name="Empty Latent Audio", + category="latent/audio", + inputs=[ + IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[IO.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = round((seconds * 44100 / 2048) / 2) * 2 - latent = torch.zeros([batch_size, 64, length], device=self.device) - return ({"samples":latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples":latent, "type": "audio"}) -class ConditioningStableAudio: + generate = execute # TODO: remove + + +class ConditioningStableAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - }} + def define_schema(cls): + return IO.Schema( + node_id="ConditioningStableAudio", + category="conditioning", + inputs=[ + IO.Conditioning.Input("positive"), + IO.Conditioning.Input("negative"), + IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1), + IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ], + ) - RETURN_TYPES = ("CONDITIONING","CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "append" - - CATEGORY = "conditioning" - - def append(self, positive, negative, seconds_start, seconds_total): + @classmethod + def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}) negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}) - return (positive, negative) + return IO.NodeOutput(positive, negative) -class VAEEncodeAudio: + append = execute # TODO: remove + + +class VAEEncodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" + def define_schema(cls): + return IO.Schema( + node_id="VAEEncodeAudio", + display_name="VAE Encode Audio", + category="latent/audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent/audio" - - def encode(self, vae, audio): + @classmethod + def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] if 44100 != sample_rate: waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) @@ -72,213 +87,134 @@ class VAEEncodeAudio: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return ({"samples":t}, ) + return IO.NodeOutput({"samples":t}) -class VAEDecodeAudio: + encode = execute # TODO: remove + + +class VAEDecodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} - RETURN_TYPES = ("AUDIO",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudio", + display_name="VAE Decode Audio", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "latent/audio" - - def decode(self, vae, samples): + @classmethod + def execute(cls, vae, samples) -> IO.NodeOutput: audio = vae.decode(samples["samples"]).movedim(-1, 1) std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return ({"waveform": audio, "sample_rate": 44100}, ) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + + decode = execute # TODO: remove -def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"): - - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results: list[FileLocator] = [] - - # Prepare metadata dictionary - metadata = {} - if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - # Opus supported sample rates - OPUS_RATES = [8000, 12000, 16000, 24000, 48000] - - for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): - filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" - output_path = os.path.join(full_output_folder, file) - - # Use original sample rate initially - sample_rate = audio["sample_rate"] - - # Handle Opus sample rate requirements - if format == "opus": - if sample_rate > 48000: - sample_rate = 48000 - elif sample_rate not in OPUS_RATES: - # Find the next highest supported rate - for rate in sorted(OPUS_RATES): - if rate > sample_rate: - sample_rate = rate - break - if sample_rate not in OPUS_RATES: # Fallback if still not supported - sample_rate = 48000 - - # Resample if necessary - if sample_rate != audio["sample_rate"]: - waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) - - # Create output with specified format - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format=format) - - # Set metadata on the container - for key, value in metadata.items(): - output_container.metadata[key] = value - - layout = 'mono' if waveform.shape[0] == 1 else 'stereo' - # Set up the output stream with appropriate properties - if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) - if quality == "64k": - out_stream.bit_rate = 64000 - elif quality == "96k": - out_stream.bit_rate = 96000 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "192k": - out_stream.bit_rate = 192000 - elif quality == "320k": - out_stream.bit_rate = 320000 - elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) - if quality == "V0": - #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool - out_stream.codec_context.qscale = 1 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "320k": - out_stream.bit_rate = 320000 - else: #format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout) - frame.sample_rate = sample_rate - frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - - # Flush encoder - output_container.mux(out_stream.encode(None)) - - # Close containers - output_container.close() - - # Write the output to file - output_buffer.seek(0) - with open(output_path, 'wb') as f: - f.write(output_buffer.getbuffer()) - - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - return { "ui": { "audio": results } } - -class SaveAudio: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudio", + display_name="Save Audio (FLAC)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) + ) - RETURN_TYPES = () - FUNCTION = "save_flac" + save_flac = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo) - -class SaveAudioMP3: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioMP3(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioMP3", + display_name="Save Audio (MP3)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["V0", "128k", "320k"], {"default": "V0"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_mp3" + save_mp3 = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class SaveAudioOpus: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioOpus(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioOpus", + display_name="Save Audio (Opus)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": { "audio": ("AUDIO", ), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_opus" + save_opus = execute # TODO: remove - OUTPUT_NODE = True - CATEGORY = "audio" - - def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) - -class PreviewAudio(SaveAudio): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) +class PreviewAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewAudio", + display_name="Preview Audio", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"audio": ("AUDIO", ), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio) -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + + save_flac = execute # TODO: remove + def f32_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to float 32 bits PCM format.""" @@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]: wav = f32_pcm(wav) return wav, sr -class LoadAudio: +class LoadAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) - return {"required": {"audio": (sorted(files), {"audio_upload": True})}} + return IO.Schema( + node_id="LoadAudio", + display_name="Load Audio", + category="audio", + inputs=[ + IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO", ) - FUNCTION = "load" - - def load(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return IO.NodeOutput(audio) @classmethod - def IS_CHANGED(s, audio): + def fingerprint_inputs(cls, audio): image_path = folder_paths.get_annotated_filepath(audio) m = hashlib.sha256() with open(image_path, 'rb') as f: @@ -343,46 +283,69 @@ class LoadAudio: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, audio): + def validate_inputs(cls, audio): if not folder_paths.exists_annotated_filepath(audio): return "Invalid audio file: {}".format(audio) return True -class RecordAudio: + load = execute # TODO: remove + + +class RecordAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO_RECORD", {})}} + def define_schema(cls): + return IO.Schema( + node_id="RecordAudio", + display_name="Record Audio", + category="audio", + inputs=[ + IO.Custom("AUDIO_RECORD").Input("audio"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO", ) - FUNCTION = "load" - - def load(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio, ) + return IO.NodeOutput(audio) + + load = execute # TODO: remove -class TrimAudioDuration: +class TrimAudioDuration(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio": ("AUDIO",), - "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}), - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TrimAudioDuration", + display_name="Trim Audio Duration", + description="Trim audio tensor into chosen time range.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input( + "start_index", + default=0.0, + min=-0xffffffffffffffff, + max=0xffffffffffffffff, + step=0.01, + tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).", + ), + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + step=0.01, + tooltip="Duration in seconds", + ), + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "trim" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Trim audio tensor into chosen time range." - - def trim(self, audio, start_index, duration): + @classmethod + def execute(cls, audio, start_index, duration) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] audio_length = waveform.shape[-1] @@ -399,23 +362,30 @@ class TrimAudioDuration: if start_frame >= end_frame: raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") - return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate}) + + trim = execute # TODO: remove -class SplitAudioChannels: +class SplitAudioChannels(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - }} + def define_schema(cls): + return IO.Schema( + node_id="SplitAudioChannels", + display_name="Split Audio Channels", + description="Separates the audio into left and right channels.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + outputs=[ + IO.Audio.Output(display_name="left"), + IO.Audio.Output(display_name="right"), + ], + ) - RETURN_TYPES = ("AUDIO", "AUDIO") - RETURN_NAMES = ("left", "right") - FUNCTION = "separate" - CATEGORY = "audio" - DESCRIPTION = "Separates the audio into left and right channels." - - def separate(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] @@ -425,7 +395,9 @@ class SplitAudioChannels: left_channel = waveform[..., 0:1, :] right_channel = waveform[..., 1:2, :] - return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + separate = execute # TODO: remove def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): @@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_ return waveform_1, waveform_2, output_sample_rate -class AudioConcat: +class AudioConcat(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioConcat", + display_name="Audio Concat", + description="Concatenates the audio1 to audio2 in the specified direction.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "direction", + options=['after', 'before'], + default="after", + tooltip="Whether to append audio2 after or before audio1.", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "concat" - CATEGORY = "audio" - DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction." - - def concat(self, audio1, audio2, direction): + @classmethod + def execute(cls, audio1, audio2, direction) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -477,26 +457,33 @@ class AudioConcat: elif direction == 'before': concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) - return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate}) + + concat = execute # TODO: remove -class AudioMerge: +class AudioMerge(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="AudioMerge", + display_name="Audio Merge", + description="Combine two audio tracks by overlaying their waveforms.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "merge_method", + options=["add", "mean", "subtract", "multiply"], + tooltip="The method used to combine the audio waveforms.", + ) + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "merge" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Combine two audio tracks by overlaying their waveforms." - - def merge(self, audio1, audio2, merge_method): + @classmethod + def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -530,85 +517,108 @@ class AudioMerge: if max_val > 1.0: waveform = waveform / max_val - return ({"waveform": waveform, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate}) + + merge = execute # TODO: remove -class AudioAdjustVolume: +class AudioAdjustVolume(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioAdjustVolume", + display_name="Audio Adjust Volume", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Int.Input( + "volume", + default=1, + min=-100, + max=100, + tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "adjust_volume" - CATEGORY = "audio" - - def adjust_volume(self, audio, volume): + @classmethod + def execute(cls, audio, volume) -> IO.NodeOutput: if volume == 0: - return (audio,) + return IO.NodeOutput(audio) waveform = audio["waveform"] sample_rate = audio["sample_rate"] gain = 10 ** (volume / 20) waveform = waveform * gain - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + adjust_volume = execute # TODO: remove -class EmptyAudio: +class EmptyAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}), - "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}), - "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="EmptyAudio", + display_name="Empty Audio", + category="audio", + inputs=[ + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + max=0xffffffffffffffff, + step=0.01, + tooltip="Duration of the empty audio clip in seconds", + ), + IO.Float.Input( + "sample_rate", + default=44100, + tooltip="Sample rate of the empty audio clip.", + ), + IO.Float.Input( + "channels", + default=2, + min=1, + max=2, + tooltip="Number of audio channels (1 for mono, 2 for stereo).", + ), + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "create_empty_audio" - CATEGORY = "audio" - - def create_empty_audio(self, duration, sample_rate, channels): + @classmethod + def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput: num_samples = int(round(duration * sample_rate)) waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + create_empty_audio = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "EmptyLatentAudio": EmptyLatentAudio, - "VAEEncodeAudio": VAEEncodeAudio, - "VAEDecodeAudio": VAEDecodeAudio, - "SaveAudio": SaveAudio, - "SaveAudioMP3": SaveAudioMP3, - "SaveAudioOpus": SaveAudioOpus, - "LoadAudio": LoadAudio, - "PreviewAudio": PreviewAudio, - "ConditioningStableAudio": ConditioningStableAudio, - "RecordAudio": RecordAudio, - "TrimAudioDuration": TrimAudioDuration, - "SplitAudioChannels": SplitAudioChannels, - "AudioConcat": AudioConcat, - "AudioMerge": AudioMerge, - "AudioAdjustVolume": AudioAdjustVolume, - "EmptyAudio": EmptyAudio, -} +class AudioExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentAudio, + VAEEncodeAudio, + VAEDecodeAudio, + SaveAudio, + SaveAudioMP3, + SaveAudioOpus, + LoadAudio, + PreviewAudio, + ConditioningStableAudio, + RecordAudio, + TrimAudioDuration, + SplitAudioChannels, + AudioConcat, + AudioMerge, + AudioAdjustVolume, + EmptyAudio, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "EmptyLatentAudio": "Empty Latent Audio", - "VAEEncodeAudio": "VAE Encode Audio", - "VAEDecodeAudio": "VAE Decode Audio", - "PreviewAudio": "Preview Audio", - "LoadAudio": "Load Audio", - "SaveAudio": "Save Audio (FLAC)", - "SaveAudioMP3": "Save Audio (MP3)", - "SaveAudioOpus": "Save Audio (Opus)", - "RecordAudio": "Record Audio", - "TrimAudioDuration": "Trim Audio Duration", - "SplitAudioChannels": "Split Audio Channels", - "AudioConcat": "Audio Concat", - "AudioMerge": "Audio Merge", - "AudioAdjustVolume": "Audio Adjust Volume", - "EmptyAudio": "Empty Audio", -} +async def comfy_entrypoint() -> AudioExtension: + return AudioExtension() diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index cb24ab709..19b8baaf4 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode): noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) if multi_res: # use first latent as dummy latent if multi_res - latents = latents[0].repeat(num_images, 1, 1, 1) + latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) guider.sample( noise.generate_noise({"samples": latents}), latents,