mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 21:12:30 +08:00
* feat(api-nodes): add Sonilo nodes Signed-off-by: bigcat88 <bigcat88@icloud.com> * fix: do not spam frontend with each chunk arrival Signed-off-by: bigcat88 <bigcat88@icloud.com> * updated pricing badge Signed-off-by: bigcat88 <bigcat88@icloud.com> --------- Signed-off-by: bigcat88 <bigcat88@icloud.com>
288 lines
10 KiB
Python
288 lines
10 KiB
Python
import base64
|
|
import json
|
|
import logging
|
|
import time
|
|
from urllib.parse import urljoin
|
|
|
|
import aiohttp
|
|
from typing_extensions import override
|
|
|
|
from comfy_api.latest import IO, ComfyExtension, Input
|
|
from comfy_api_nodes.util import (
|
|
ApiEndpoint,
|
|
audio_bytes_to_audio_input,
|
|
upload_video_to_comfyapi,
|
|
validate_string,
|
|
)
|
|
from comfy_api_nodes.util._helpers import (
|
|
default_base_url,
|
|
get_auth_header,
|
|
get_node_id,
|
|
is_processing_interrupted,
|
|
)
|
|
from comfy_api_nodes.util.common_exceptions import ProcessingInterrupted
|
|
from server import PromptServer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SoniloVideoToMusic(IO.ComfyNode):
    """API node that sends a video to Sonilo and returns the generated music as audio."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Build the node schema: video, optional prompt and seed inputs; one audio output."""
        video_input = IO.Video.Input(
            "video",
            tooltip="Input video to generate music from. Maximum duration: 6 minutes.",
        )
        prompt_input = IO.String.Input(
            "prompt",
            default="",
            multiline=True,
            tooltip=(
                "Optional text prompt to guide music generation. "
                "Leave empty for best quality - the model will fully analyze the video content."
            ),
        )
        seed_input = IO.Int.Input(
            "seed",
            default=0,
            min=0,
            max=0xFFFFFFFFFFFFFFFF,
            control_after_generate=True,
            tooltip=(
                "Seed for reproducibility. Currently ignored by the Sonilo "
                "service but kept for graph consistency."
            ),
        )
        # Flat per-second price shown on the node.
        badge = IO.PriceBadge(
            expr='{"type":"usd","usd":0.009,"format":{"suffix":"/second"}}',
        )
        return IO.Schema(
            node_id="SoniloVideoToMusic",
            display_name="Sonilo Video to Music",
            category="api node/audio/Sonilo",
            description=(
                "Generate music from video content using Sonilo's AI model. "
                "Analyzes the video and creates matching music."
            ),
            inputs=[video_input, prompt_input, seed_input],
            outputs=[IO.Audio.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=badge,
        )

    @classmethod
    async def execute(
        cls,
        video: Input.Video,
        prompt: str = "",
        seed: int = 0,
    ) -> IO.NodeOutput:
        """Upload ``video``, request music for it, and wrap the received audio as the node output.

        Args:
            video: Source video; uploaded with ``max_duration=360``
                (presumably seconds, matching the 6-minute limit in the tooltip).
            prompt: Optional guidance text; only sent when non-blank, with
                surrounding whitespace removed.
            seed: Accepted for the graph but not forwarded to the service.
        """
        uploaded_url = await upload_video_to_comfyapi(cls, video, max_duration=360)

        payload = aiohttp.FormData()
        payload.add_field("video_url", uploaded_url)
        guidance = prompt.strip()
        if guidance:
            payload.add_field("prompt", guidance)

        endpoint = ApiEndpoint(path="/proxy/sonilo/v2m/generate", method="POST")
        raw_audio = await _stream_sonilo_music(cls, endpoint, payload)
        return IO.NodeOutput(audio_bytes_to_audio_input(raw_audio))
|
|
|
|
|
|
class SoniloTextToMusic(IO.ComfyNode):
    """Generate music from a text prompt using Sonilo's AI model."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: prompt, optional duration and seed inputs; one audio output."""
        return IO.Schema(
            node_id="SoniloTextToMusic",
            display_name="Sonilo Text to Music",
            category="api node/audio/Sonilo",
            description="Generate music from a text prompt using Sonilo's AI model. "
            "Leave duration at 0 to let the model infer it from the prompt.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    default="",
                    multiline=True,
                    tooltip="Text prompt describing the music to generate.",
                ),
                IO.Int.Input(
                    "duration",
                    default=0,
                    min=0,
                    max=360,
                    tooltip="Target duration in seconds. Set to 0 to let the model "
                    "infer the duration from the prompt. Maximum: 6 minutes.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=0xFFFFFFFFFFFFFFFF,
                    control_after_generate=True,
                    tooltip="Seed for reproducibility. Currently ignored by the Sonilo "
                    "service but kept for graph consistency.",
                ),
            ],
            outputs=[IO.Audio.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            # Total price when a duration is chosen, otherwise the per-second rate.
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
                expr="""
                (
                  widgets.duration > 0
                    ? {"type":"usd","usd": 0.005 * widgets.duration}
                    : {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}}
                )
                """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        duration: int = 0,
        seed: int = 0,
    ) -> IO.NodeOutput:
        """Request music for ``prompt`` and wrap the received audio as the node output.

        Args:
            prompt: Required text; validated to be non-empty after stripping
                whitespace, and sent stripped.
            duration: Target length in seconds; omitted from the request when 0.
            seed: Accepted for the graph but not forwarded to the service.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1)
        # Send exactly what was validated (the stripped text), matching how
        # SoniloVideoToMusic submits its prompt.
        prompt = prompt.strip()

        form = aiohttp.FormData()
        form.add_field("prompt", prompt)
        if duration > 0:
            form.add_field("duration", str(duration))

        audio_bytes = await _stream_sonilo_music(
            cls,
            ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
            form,
        )
        return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes))
|
|
|
|
|
|
async def _stream_sonilo_music(
    cls: type[IO.ComfyNode],
    endpoint: ApiEndpoint,
    form: aiohttp.FormData,
) -> bytes:
    """POST ``form`` to Sonilo, read the NDJSON event stream, and return one stream's audio bytes.

    Each response line is a JSON object with a ``type`` field. ``audio_chunk``
    events carry base64 audio in ``data`` and are grouped by ``stream_index``;
    ``duration`` and ``title``/``titles`` events only update the progress text;
    ``error`` aborts; ``complete`` (or end of stream) finishes reading.

    Args:
        cls: Node class, used to obtain the auth header and node id.
        endpoint: Target endpoint; its path is joined onto the default base URL
            and its headers are merged over the auth header.
        form: Form payload to submit.

    Returns:
        The concatenated chunks of stream 0 when present, otherwise of the
        lowest-numbered stream received.

    Raises:
        ProcessingInterrupted: If processing is interrupted between lines.
        Exception: On an HTTP status >= 400, an ``error`` event, an audio chunk
            with missing or undecodable data, or when no audio was received.
    """
    url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))

    headers: dict[str, str] = {}
    headers.update(get_auth_header(cls))
    headers.update(endpoint.headers)

    node_id = get_node_id(cls)
    start_ts = time.monotonic()
    last_chunk_status_ts = 0.0
    audio_streams: dict[int, list[bytes]] = {}
    total_chunks = 0
    title: str | None = None

    timeout = aiohttp.ClientTimeout(total=1200.0, sock_read=300.0)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        PromptServer.instance.send_progress_text("Status: Queued", node_id)
        async with session.post(url, data=form, headers=headers) as resp:
            if resp.status >= 400:
                msg = await _extract_error_message(resp)
                raise Exception(f"Sonilo API error ({resp.status}): {msg}")

            while True:
                if is_processing_interrupted():
                    raise ProcessingInterrupted("Task cancelled")

                raw_line = await resp.content.readline()
                if not raw_line:
                    break  # end of stream

                try:
                    line = raw_line.decode("utf-8").strip()
                    if not line:
                        continue
                    evt = json.loads(line)
                except ValueError:
                    # UnicodeDecodeError and json.JSONDecodeError both subclass ValueError.
                    evt = None
                if not isinstance(evt, dict):
                    # Also covers valid JSON that is not an object (null, list, number...).
                    logger.warning("Sonilo: skipping malformed NDJSON line")
                    continue

                evt_type = evt.get("type")
                if evt_type == "error":
                    code = evt.get("code", "UNKNOWN")
                    message = evt.get("message", "Unknown error")
                    raise Exception(f"Sonilo generation error ({code}): {message}")
                if evt_type == "complete":
                    break

                if evt_type == "duration":
                    duration_sec = evt.get("duration_sec")
                    # Progress text is cosmetic: never let an unexpected value
                    # type break the ``.1f`` format and abort the generation.
                    if isinstance(duration_sec, (int, float)):
                        PromptServer.instance.send_progress_text(
                            f"Status: Generating\nVideo duration: {duration_sec:.1f}s",
                            node_id,
                        )
                elif evt_type in ("titles", "title"):
                    # v2m sends a "titles" list, t2m sends a scalar "title"
                    if evt_type == "titles":
                        titles = evt.get("titles")
                        if isinstance(titles, list) and titles:
                            title = titles[0]
                    else:
                        title = evt.get("title") or title
                    if title:
                        PromptServer.instance.send_progress_text(
                            f"Status: Generating\nTitle: {title}",
                            node_id,
                        )
                elif evt_type == "audio_chunk":
                    stream_idx = evt.get("stream_index", 0)
                    try:
                        chunk_data = base64.b64decode(evt["data"])
                    except (KeyError, TypeError, ValueError) as err:
                        # Dropping a chunk would silently corrupt the audio, so fail loudly.
                        raise Exception(
                            "Sonilo API returned an audio chunk with missing or invalid data."
                        ) from err

                    audio_streams.setdefault(stream_idx, []).append(chunk_data)
                    total_chunks += 1

                    # Throttle status updates to at most one per second so the
                    # frontend is not flooded with a message per chunk.
                    now = time.monotonic()
                    if now - last_chunk_status_ts >= 1.0:
                        elapsed = int(now - start_ts)
                        status_lines = ["Status: Receiving audio"]
                        if title:
                            status_lines.append(f"Title: {title}")
                        status_lines.append(f"Chunks received: {total_chunks}")
                        status_lines.append(f"Time elapsed: {elapsed}s")
                        PromptServer.instance.send_progress_text("\n".join(status_lines), node_id)
                        last_chunk_status_ts = now

    if not audio_streams:
        raise Exception("Sonilo API returned no audio data.")

    PromptServer.instance.send_progress_text("Status: Completed", node_id)
    selected_stream = 0 if 0 in audio_streams else min(audio_streams)
    return b"".join(audio_streams[selected_stream])
|
|
|
|
|
|
async def _extract_error_message(resp: aiohttp.ClientResponse) -> str:
    """Extract a human-readable error message from an HTTP error response.

    Looks for a ``detail`` entry in the JSON body. A non-empty mapping yields
    its ``message`` value (or the mapping itself when it has no ``message``);
    any other non-empty ``detail`` is stringified. When the body cannot be
    parsed as JSON, is not a JSON object, or carries no usable ``detail``, the
    raw response text is returned, so the caller never reports a bare ``{}``.

    Args:
        resp: The error response to inspect.

    Returns:
        The best available error description as a string.
    """
    try:
        error_body = await resp.json()
    except Exception:
        # Best effort: the body could not be read as JSON, so use the raw text.
        return await resp.text()

    detail = error_body.get("detail") if isinstance(error_body, dict) else None
    if isinstance(detail, dict):
        if detail:
            return str(detail.get("message", detail))
    elif detail:
        return str(detail)
    # Missing, empty or null "detail": the raw body is more useful than "{}".
    return await resp.text()
|
|
|
|
|
|
class SoniloExtension(ComfyExtension):
    """Extension that exposes the Sonilo music-generation nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[IO.ComfyNode]] = [
            SoniloVideoToMusic,
            SoniloTextToMusic,
        ]
        return node_classes
|
|
|
|
|
|
async def comfy_entrypoint() -> SoniloExtension:
    """Create and return the Sonilo extension instance."""
    extension = SoniloExtension()
    return extension
|