diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py new file mode 100644 index 000000000..5518f5902 --- /dev/null +++ b/comfy_api_nodes/nodes_sonilo.py @@ -0,0 +1,287 @@ +import base64 +import json +import logging +import time +from urllib.parse import urljoin + +import aiohttp +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_bytes_to_audio_input, + upload_video_to_comfyapi, + validate_string, +) +from comfy_api_nodes.util._helpers import ( + default_base_url, + get_auth_header, + get_node_id, + is_processing_interrupted, +) +from comfy_api_nodes.util.common_exceptions import ProcessingInterrupted +from server import PromptServer + +logger = logging.getLogger(__name__) + + +class SoniloVideoToMusic(IO.ComfyNode): + """Generate music from video using Sonilo's AI model.""" + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="SoniloVideoToMusic", + display_name="Sonilo Video to Music", + category="api node/audio/Sonilo", + description="Generate music from video content using Sonilo's AI model. " + "Analyzes the video and creates matching music.", + inputs=[ + IO.Video.Input( + "video", + tooltip="Input video to generate music from. Maximum duration: 6 minutes.", + ), + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Optional text prompt to guide music generation. " + "Leave empty for best quality - the model will fully analyze the video content.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed for reproducibility. Currently ignored by the Sonilo " + "service but kept for graph consistency.", + ), + ], + outputs=[IO.Audio.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr='{"type":"usd","usd":0.009,"format":{"suffix":"/second"}}', + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + prompt: str = "", + seed: int = 0, + ) -> IO.NodeOutput: + video_url = await upload_video_to_comfyapi(cls, video, max_duration=360) + form = aiohttp.FormData() + form.add_field("video_url", video_url) + if prompt.strip(): + form.add_field("prompt", prompt.strip()) + audio_bytes = await _stream_sonilo_music( + cls, + ApiEndpoint(path="/proxy/sonilo/v2m/generate", method="POST"), + form, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes)) + + +class SoniloTextToMusic(IO.ComfyNode): + """Generate music from a text prompt using Sonilo's AI model.""" + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="SoniloTextToMusic", + display_name="Sonilo Text to Music", + category="api node/audio/Sonilo", + description="Generate music from a text prompt using Sonilo's AI model. " + "Leave duration at 0 to let the model infer it from the prompt.", + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt describing the music to generate.", + ), + IO.Int.Input( + "duration", + default=0, + min=0, + max=360, + tooltip="Target duration in seconds. Set to 0 to let the model " + "infer the duration from the prompt. Maximum: 6 minutes.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed for reproducibility. Currently ignored by the Sonilo " + "service but kept for graph consistency.", + ), + ], + outputs=[IO.Audio.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr=""" + ( + widgets.duration > 0 + ? {"type":"usd","usd": 0.005 * widgets.duration} + : {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + duration: int = 0, + seed: int = 0, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + form = aiohttp.FormData() + form.add_field("prompt", prompt) + if duration > 0: + form.add_field("duration", str(duration)) + audio_bytes = await _stream_sonilo_music( + cls, + ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"), + form, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes)) + + +async def _stream_sonilo_music( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + form: aiohttp.FormData, +) -> bytes: + """POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes.""" + url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/")) + + headers: dict[str, str] = {} + headers.update(get_auth_header(cls)) + headers.update(endpoint.headers) + + node_id = get_node_id(cls) + start_ts = time.monotonic() + last_chunk_status_ts = 0.0 + audio_streams: dict[int, list[bytes]] = {} + title: str | None = None + + timeout = aiohttp.ClientTimeout(total=1200.0, sock_read=300.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + PromptServer.instance.send_progress_text("Status: Queued", node_id) + async with session.post(url, data=form, headers=headers) as resp: + if resp.status >= 400: + msg = await _extract_error_message(resp) + raise Exception(f"Sonilo API error ({resp.status}): {msg}") + + while True: + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + raw_line = await resp.content.readline() + if not raw_line: + break + + line = raw_line.decode("utf-8").strip() + if not line: + continue + + try: + evt = json.loads(line) + except json.JSONDecodeError: + logger.warning("Sonilo: skipping malformed NDJSON line") + continue + + evt_type = evt.get("type") + if evt_type == "error": + code = evt.get("code", "UNKNOWN") + message = evt.get("message", "Unknown error") + raise Exception(f"Sonilo generation error ({code}): {message}") + if evt_type == "duration": + duration_sec = evt.get("duration_sec") + if duration_sec is not None: + PromptServer.instance.send_progress_text( + f"Status: Generating\nVideo duration: {duration_sec:.1f}s", + node_id, + ) + elif evt_type in ("titles", "title"): + # v2m sends a "titles" list, t2m sends a scalar "title" + if evt_type == "titles": + titles = evt.get("titles", []) + if titles: + title = titles[0] + else: + title = evt.get("title") or title + if title: + PromptServer.instance.send_progress_text( + f"Status: Generating\nTitle: {title}", + node_id, + ) + elif evt_type == "audio_chunk": + stream_idx = evt.get("stream_index", 0) + chunk_data = base64.b64decode(evt["data"]) + + if stream_idx not in audio_streams: + audio_streams[stream_idx] = [] + audio_streams[stream_idx].append(chunk_data) + + now = time.monotonic() + if now - last_chunk_status_ts >= 1.0: + total_chunks = sum(len(chunks) for chunks in audio_streams.values()) + elapsed = int(now - start_ts) + status_lines = ["Status: Receiving audio"] + if title: + status_lines.append(f"Title: {title}") + status_lines.append(f"Chunks received: {total_chunks}") + status_lines.append(f"Time elapsed: {elapsed}s") + PromptServer.instance.send_progress_text("\n".join(status_lines), node_id) + last_chunk_status_ts = now + elif evt_type == "complete": + break + + if not audio_streams: + raise Exception("Sonilo API returned no audio data.") + + PromptServer.instance.send_progress_text("Status: Completed", node_id) + selected_stream = 0 if 0 in audio_streams else min(audio_streams) + return b"".join(audio_streams[selected_stream]) + + +async def _extract_error_message(resp: aiohttp.ClientResponse) -> str: + """Extract a human-readable error message from an HTTP error response.""" + try: + error_body = await resp.json() + detail = error_body.get("detail", {}) + if isinstance(detail, dict): + return detail.get("message", str(detail)) + return str(detail) + except Exception: + return await resp.text() + + +class SoniloExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [SoniloVideoToMusic, SoniloTextToMusic] + + +async def comfy_entrypoint() -> SoniloExtension: + return SoniloExtension()