ComfyUI/comfy_api_nodes/nodes_wavespeed.py
onlyforthesis b1f5a3254b Add WaveSpeed text-to-image node using Flux models
Adds a new `WavespeedTextToImageNode` that generates images via
WaveSpeed's fast Flux inference API (flux-dev, flux-dev-fp8,
flux-schnell, flux-schnell-fp8).

Also adds the corresponding `WavespeedTextToImageRequest` Pydantic
model to the wavespeed API module.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-27 10:00:29 +08:00

332 lines
12 KiB
Python

from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.wavespeed import (
FlashVSRRequest,
TaskCreatedResponse,
TaskResultResponse,
SeedVR2ImageRequest,
WavespeedTextToImageRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_video_duration,
upload_images_to_comfyapi,
get_number_of_images,
download_url_to_image_tensor,
)
class WavespeedFlashVSRNode(IO.ComfyNode):
    """Upscale an MP4 video through WaveSpeed's FlashVSR endpoint behind the Comfy API proxy."""

    @classmethod
    def define_schema(cls):
        """Declare the video/resolution inputs, video output, hidden auth fields and price badge."""
        # JSONata expression used for the per-second price badge.
        # NOTE(review): the keys are lowercase ("2k", "4k") while the combo options are
        # "2K"/"4K"; this presumably relies on widget values being lower-cased before
        # evaluation — confirm against the frontend.
        price_expr = """
(
$price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032};
{
"type":"usd",
"usd": $lookup($price_for_1sec, widgets.target_resolution),
"format":{"suffix": "/second", "approximate": true}
}
)
"""
        node_inputs = [
            IO.Video.Input("video"),
            IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
        ]
        hidden_fields = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="WavespeedFlashVSRNode",
            display_name="FlashVSR Video Upscale",
            category="api node/video/WaveSpeed",
            description=(
                "Fast, high-quality video upscaler that boosts resolution and "
                "restores clarity for low-resolution or blurry footage."
            ),
            inputs=node_inputs,
            outputs=[IO.Video.Output()],
            hidden=hidden_fields,
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
                expr=price_expr,
            ),
        )

    @classmethod
    async def execute(
        cls,
        video: Input.Video,
        target_resolution: str,
    ) -> IO.NodeOutput:
        """Validate and upload the video, run the FlashVSR task, and return the upscaled video.

        The video must be MP4 and between 5 seconds and 10 minutes long. The resolution
        label is sent lower-cased. The task is polled every 10 s, up to 480 attempts.

        Raises:
            ValueError: if task creation or task processing reports a code other than 200.
        """
        # Reject unsupported inputs before spending an upload.
        validate_container_format_is_mp4(video)
        validate_video_duration(video, min_duration=5, max_duration=60 * 10)

        resolution = target_resolution.lower()
        video_url = await upload_video_to_comfyapi(cls, video)
        request = FlashVSRRequest(
            target_resolution=resolution,
            video=video_url,
            duration=video.get_duration(),
        )

        created = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
            response_model=TaskCreatedResponse,
            data=request,
        )
        if created.code != 200:
            raise ValueError(f"Task creation fails with code={created.code} and message={created.message}")

        def task_status(response):
            # A response without a data payload is reported as "failed".
            if response.data is None:
                return "failed"
            return response.data.status

        result = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{created.data.id}/result"),
            response_model=TaskResultResponse,
            status_extractor=task_status,
            poll_interval=10.0,
            max_poll_attempts=480,
        )
        if result.code != 200:
            raise ValueError(
                f"Task processing failed with code={result.code} and message={result.message}"
            )

        output_url = result.data.outputs[0]
        return IO.NodeOutput(await download_url_to_video_output(output_url))
class WavespeedImageUpscaleNode(IO.ComfyNode):
    """Upscale a single image with WaveSpeed's SeedVR2 or Ultimate upscaler endpoint."""

    @classmethod
    def define_schema(cls):
        """Declare the model/image/resolution inputs, image output, hidden auth fields and price badge."""
        # JSONata expression for the flat per-image price badge.
        # NOTE(review): keys are lowercase while the combo options are "SeedVR2"/"Ultimate";
        # this presumably relies on widget values being lower-cased before evaluation — confirm.
        price_expr = """
(
$prices := {"seedvr2": 0.01, "ultimate": 0.06};
{"type":"usd", "usd": $lookup($prices, widgets.model)}
)
"""
        node_inputs = [
            IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
            IO.Image.Input("image"),
            IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
        ]
        hidden_fields = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="WavespeedImageUpscaleNode",
            display_name="WaveSpeed Image Upscale",
            category="api node/image/WaveSpeed",
            description=(
                "Boost image resolution and quality, upscaling photos "
                "to 4K or 8K for sharp, detailed results."
            ),
            inputs=node_inputs,
            outputs=[IO.Image.Output()],
            hidden=hidden_fields,
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
                expr=price_expr,
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        image: Input.Image,
        target_resolution: str,
    ) -> IO.NodeOutput:
        """Upload one image, run the selected upscaler task, and return the resulting image.

        Raises:
            ValueError: if the input does not hold exactly one image, or if task creation or
                task processing reports a code other than 200.
        """
        if get_number_of_images(image) != 1:
            raise ValueError("Exactly one input image is required.")

        # "SeedVR2" has its own endpoint; every other value goes to the Ultimate upscaler.
        model_path = {"SeedVR2": "seedvr2/image"}.get(model, "ultimate-image-upscaler")

        resolution = target_resolution.lower()
        uploaded_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
        request = SeedVR2ImageRequest(target_resolution=resolution, image=uploaded_urls[0])

        created = await sync_op(
            cls,
            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
            response_model=TaskCreatedResponse,
            data=request,
        )
        if created.code != 200:
            raise ValueError(f"Task creation fails with code={created.code} and message={created.message}")

        def task_status(response):
            # A response without a data payload is reported as "failed".
            if response.data is None:
                return "failed"
            return response.data.status

        result = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{created.data.id}/result"),
            response_model=TaskResultResponse,
            status_extractor=task_status,
            poll_interval=10.0,
            max_poll_attempts=480,
        )
        if result.code != 200:
            raise ValueError(
                f"Task processing failed with code={result.code} and message={result.message}"
            )

        output_url = result.data.outputs[0]
        return IO.NodeOutput(await download_url_to_image_tensor(output_url))
# Model choices offered by the text-to-image node, in display order.
_TEXT_TO_IMAGE_MODELS = [
    "wavespeed-ai/flux-dev",
    "wavespeed-ai/flux-dev-fp8",
    "wavespeed-ai/flux-schnell",
    "wavespeed-ai/flux-schnell-fp8",
]
# Model choice -> endpoint slug (the choice without its "wavespeed-ai/" namespace),
# which is appended to /proxy/wavespeed/api/v3/wavespeed-ai/ when submitting a task.
_MODEL_ENDPOINT = {name: name.removeprefix("wavespeed-ai/") for name in _TEXT_TO_IMAGE_MODELS}
# Schnell variants; the text-to-image node sends a fixed guidance scale for these.
_SCHNELL_MODELS = {name for name in _TEXT_TO_IMAGE_MODELS if "schnell" in name}
class WavespeedTextToImageNode(IO.ComfyNode):
    """Generate an image from a text prompt using one of WaveSpeed's Flux endpoints."""

    @classmethod
    def define_schema(cls):
        """Declare the model, prompt and sampling inputs, the image output and the price badge."""
        return IO.Schema(
            node_id="WavespeedTextToImageNode",
            display_name="WaveSpeed Text to Image",
            category="api node/image/WaveSpeed",
            description="Generate images from text prompts using WaveSpeed's fast Flux inference.",
            inputs=[
                IO.Combo.Input("model", options=_TEXT_TO_IMAGE_MODELS),
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Text prompt describing the image to generate.",
                ),
                IO.Int.Input(
                    "width",
                    default=1024,
                    min=256,
                    max=2048,
                    step=64,
                    tooltip="Image width in pixels.",
                ),
                IO.Int.Input(
                    "height",
                    default=1024,
                    min=256,
                    max=2048,
                    step=64,
                    tooltip="Image height in pixels.",
                ),
                IO.Int.Input(
                    "steps",
                    default=28,
                    min=1,
                    max=50,
                    tooltip="Number of inference steps. Schnell models work well with 4 steps.",
                ),
                IO.Float.Input(
                    "guidance_scale",
                    default=3.5,
                    min=0.0,
                    max=20.0,
                    step=0.1,
                    tooltip="Guidance scale (CFG). Not used for Schnell models.",
                    advanced=True,
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    # The minimum is 0, so -1 cannot be entered; the tooltip no longer
                    # advertises "-1 for random".
                    min=0,
                    # NOTE(review): confirm the WaveSpeed API accepts seeds above 2**31 - 1.
                    max=0xFFFFFFFFFFFFFFFF,
                    control_after_generate=True,
                    tooltip="Seed for reproducibility.",
                ),
                IO.Boolean.Input(
                    "safety_checker",
                    default=True,
                    tooltip="Enable the safety checker.",
                    advanced=True,
                ),
            ],
            outputs=[IO.Image.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
                expr="""
(
$prices := {
"wavespeed-ai/flux-dev": 0.003,
"wavespeed-ai/flux-dev-fp8": 0.003,
"wavespeed-ai/flux-schnell": 0.001,
"wavespeed-ai/flux-schnell-fp8": 0.001
};
{"type":"usd", "usd": $lookup($prices, widgets.model)}
)
""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        prompt: str,
        width: int,
        height: int,
        steps: int,
        guidance_scale: float,
        seed: int,
        safety_checker: bool,
    ) -> IO.NodeOutput:
        """Submit a text-to-image task, poll it until done, and return the first output image.

        For Schnell models a fixed guidance_scale of 1.0 is sent and the widget value is
        ignored. The task is polled every 3 s, up to 200 attempts.

        Raises:
            ValueError: if the prompt is empty or only whitespace; if task creation or task
                processing reports a code other than 200; or if the finished task lists no
                outputs.
        """
        # The widget defaults to "", so reject empty prompts before making a paid API call.
        if not prompt.strip():
            raise ValueError("Prompt must not be empty.")

        endpoint_name = _MODEL_ENDPOINT[model]
        is_schnell = model in _SCHNELL_MODELS
        initial_res = await sync_op(
            cls,
            ApiEndpoint(
                path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{endpoint_name}",
                method="POST",
            ),
            response_model=TaskCreatedResponse,
            data=WavespeedTextToImageRequest(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=steps,
                guidance_scale=1.0 if is_schnell else guidance_scale,
                seed=seed,
                enable_safety_checker=safety_checker,
            ),
        )
        if initial_res.code != 200:
            raise ValueError(f"Task creation failed with code={initial_res.code} and message={initial_res.message}")
        final_response = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
            response_model=TaskResultResponse,
            # A response without a data payload is reported as "failed".
            status_extractor=lambda x: "failed" if x.data is None else x.data.status,
            poll_interval=3.0,
            max_poll_attempts=200,
        )
        if final_response.code != 200:
            raise ValueError(
                f"Task processing failed with code={final_response.code} and message={final_response.message}"
            )
        # Raise a clear error instead of a bare IndexError when no output URL is returned.
        outputs = final_response.data.outputs
        if not outputs:
            raise ValueError("Task completed but returned no output images.")
        return IO.NodeOutput(await download_url_to_image_tensor(outputs[0]))
class WavespeedExtension(ComfyExtension):
    """Extension that exposes the WaveSpeed API nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes: text-to-image, FlashVSR video upscale, image upscale."""
        provided_nodes: list[type[IO.ComfyNode]] = [
            WavespeedTextToImageNode,
            WavespeedFlashVSRNode,
            WavespeedImageUpscaleNode,
        ]
        return provided_nodes
async def comfy_entrypoint() -> WavespeedExtension:
    """Module entry point: build and return a new WaveSpeed extension instance."""
    extension = WavespeedExtension()
    return extension