From b1f5a3254b3c3602210c53b9c8a2d4a2e6413c16 Mon Sep 17 00:00:00 2001 From: onlyforthesis Date: Mon, 27 Apr 2026 10:00:29 +0800 Subject: [PATCH] Add WaveSpeed text-to-image node using Flux models Adds a new `WavespeedTextToImageNode` that generates images via WaveSpeed's fast Flux inference API (flux-dev, flux-dev-fp8, flux-schnell, flux-schnell-fp8). Also adds the corresponding `WavespeedTextToImageRequest` Pydantic model to the wavespeed API module. Co-Authored-By: Claude Sonnet 4.6 --- comfy_api_nodes/apis/wavespeed.py | 16 +++ comfy_api_nodes/nodes_wavespeed.py | 153 +++++++++++++++++++++++++++++ 2 files changed, 169 insertions(+) diff --git a/comfy_api_nodes/apis/wavespeed.py b/comfy_api_nodes/apis/wavespeed.py index 07a7bfa5d..c356d8522 100644 --- a/comfy_api_nodes/apis/wavespeed.py +++ b/comfy_api_nodes/apis/wavespeed.py @@ -1,6 +1,22 @@ +from typing import Optional + from pydantic import BaseModel, Field +class WavespeedTextToImageRequest(BaseModel): + prompt: str = Field(...) + width: int = Field(1024) + height: int = Field(1024) + num_inference_steps: int = Field(28) + guidance_scale: float = Field(3.5) + seed: int = Field(-1) + enable_safety_checker: bool = Field(True) + enable_sync_mode: bool = Field(False) + num_images: int = Field(1) + output_format: str = Field("jpeg") + negative_prompt: Optional[str] = Field(None) + + class SeedVR2ImageRequest(BaseModel): image: str = Field(...) target_resolution: str = Field(...) diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py index c59fafd3b..7d878321e 100644 --- a/comfy_api_nodes/nodes_wavespeed.py +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -6,6 +6,7 @@ from comfy_api_nodes.apis.wavespeed import ( TaskCreatedResponse, TaskResultResponse, SeedVR2ImageRequest, + WavespeedTextToImageRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -165,10 +166,162 @@ class WavespeedImageUpscaleNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) +_TEXT_TO_IMAGE_MODELS = [ + "wavespeed-ai/flux-dev", + "wavespeed-ai/flux-dev-fp8", + "wavespeed-ai/flux-schnell", + "wavespeed-ai/flux-schnell-fp8", +] + +_MODEL_ENDPOINT = { + "wavespeed-ai/flux-dev": "flux-dev", + "wavespeed-ai/flux-dev-fp8": "flux-dev-fp8", + "wavespeed-ai/flux-schnell": "flux-schnell", + "wavespeed-ai/flux-schnell-fp8": "flux-schnell-fp8", +} + +_SCHNELL_MODELS = {"wavespeed-ai/flux-schnell", "wavespeed-ai/flux-schnell-fp8"} + + +class WavespeedTextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedTextToImageNode", + display_name="WaveSpeed Text to Image", + category="api node/image/WaveSpeed", + description="Generate images from text prompts using WaveSpeed's fast Flux inference.", + inputs=[ + IO.Combo.Input("model", options=_TEXT_TO_IMAGE_MODELS), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt describing the image to generate.", + ), + IO.Int.Input( + "width", + default=1024, + min=256, + max=2048, + step=64, + tooltip="Image width in pixels.", + ), + IO.Int.Input( + "height", + default=1024, + min=256, + max=2048, + step=64, + tooltip="Image height in pixels.", + ), + IO.Int.Input( + "steps", + default=28, + min=1, + max=50, + tooltip="Number of inference steps. Schnell models work well with 4 steps.", + ), + IO.Float.Input( + "guidance_scale", + default=3.5, + min=0.0, + max=20.0, + step=0.1, + tooltip="Guidance scale (CFG). Not used for Schnell models.", + advanced=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed for reproducibility. Use -1 for random.", + ), + IO.Boolean.Input( + "safety_checker", + default=True, + tooltip="Enable the safety checker.", + advanced=True, + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := { + "wavespeed-ai/flux-dev": 0.003, + "wavespeed-ai/flux-dev-fp8": 0.003, + "wavespeed-ai/flux-schnell": 0.001, + "wavespeed-ai/flux-schnell-fp8": 0.001 + }; + {"type":"usd", "usd": $lookup($prices, widgets.model)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + width: int, + height: int, + steps: int, + guidance_scale: float, + seed: int, + safety_checker: bool, + ) -> IO.NodeOutput: + endpoint_name = _MODEL_ENDPOINT[model] + is_schnell = model in _SCHNELL_MODELS + initial_res = await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{endpoint_name}", + method="POST", + ), + response_model=TaskCreatedResponse, + data=WavespeedTextToImageRequest( + prompt=prompt, + width=width, + height=height, + num_inference_steps=steps, + guidance_scale=1.0 if is_schnell else guidance_scale, + seed=seed, + enable_safety_checker=safety_checker, + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation failed with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=3.0, + max_poll_attempts=200, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) + + class WavespeedExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ + WavespeedTextToImageNode, WavespeedFlashVSRNode, WavespeedImageUpscaleNode, ]