mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 07:57:29 +08:00
Add WaveSpeed text-to-image node using Flux models
Adds a new `WavespeedTextToImageNode` that generates images via WaveSpeed's fast Flux inference API (flux-dev, flux-dev-fp8, flux-schnell, flux-schnell-fp8). Also adds the corresponding `WavespeedTextToImageRequest` Pydantic model to the wavespeed API module. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7385eb2800
commit
b1f5a3254b
@ -1,6 +1,22 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class WavespeedTextToImageRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
width: int = Field(1024)
|
||||||
|
height: int = Field(1024)
|
||||||
|
num_inference_steps: int = Field(28)
|
||||||
|
guidance_scale: float = Field(3.5)
|
||||||
|
seed: int = Field(-1)
|
||||||
|
enable_safety_checker: bool = Field(True)
|
||||||
|
enable_sync_mode: bool = Field(False)
|
||||||
|
num_images: int = Field(1)
|
||||||
|
output_format: str = Field("jpeg")
|
||||||
|
negative_prompt: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class SeedVR2ImageRequest(BaseModel):
|
class SeedVR2ImageRequest(BaseModel):
|
||||||
image: str = Field(...)
|
image: str = Field(...)
|
||||||
target_resolution: str = Field(...)
|
target_resolution: str = Field(...)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from comfy_api_nodes.apis.wavespeed import (
|
|||||||
TaskCreatedResponse,
|
TaskCreatedResponse,
|
||||||
TaskResultResponse,
|
TaskResultResponse,
|
||||||
SeedVR2ImageRequest,
|
SeedVR2ImageRequest,
|
||||||
|
WavespeedTextToImageRequest,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -165,10 +166,162 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
|
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
|
||||||
|
|
||||||
|
|
||||||
|
_TEXT_TO_IMAGE_MODELS = [
|
||||||
|
"wavespeed-ai/flux-dev",
|
||||||
|
"wavespeed-ai/flux-dev-fp8",
|
||||||
|
"wavespeed-ai/flux-schnell",
|
||||||
|
"wavespeed-ai/flux-schnell-fp8",
|
||||||
|
]
|
||||||
|
|
||||||
|
_MODEL_ENDPOINT = {
|
||||||
|
"wavespeed-ai/flux-dev": "flux-dev",
|
||||||
|
"wavespeed-ai/flux-dev-fp8": "flux-dev-fp8",
|
||||||
|
"wavespeed-ai/flux-schnell": "flux-schnell",
|
||||||
|
"wavespeed-ai/flux-schnell-fp8": "flux-schnell-fp8",
|
||||||
|
}
|
||||||
|
|
||||||
|
_SCHNELL_MODELS = {"wavespeed-ai/flux-schnell", "wavespeed-ai/flux-schnell-fp8"}
|
||||||
|
|
||||||
|
|
||||||
|
class WavespeedTextToImageNode(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="WavespeedTextToImageNode",
|
||||||
|
display_name="WaveSpeed Text to Image",
|
||||||
|
category="api node/image/WaveSpeed",
|
||||||
|
description="Generate images from text prompts using WaveSpeed's fast Flux inference.",
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("model", options=_TEXT_TO_IMAGE_MODELS),
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text prompt describing the image to generate.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"width",
|
||||||
|
default=1024,
|
||||||
|
min=256,
|
||||||
|
max=2048,
|
||||||
|
step=64,
|
||||||
|
tooltip="Image width in pixels.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"height",
|
||||||
|
default=1024,
|
||||||
|
min=256,
|
||||||
|
max=2048,
|
||||||
|
step=64,
|
||||||
|
tooltip="Image height in pixels.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"steps",
|
||||||
|
default=28,
|
||||||
|
min=1,
|
||||||
|
max=50,
|
||||||
|
tooltip="Number of inference steps. Schnell models work well with 4 steps.",
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"guidance_scale",
|
||||||
|
default=3.5,
|
||||||
|
min=0.0,
|
||||||
|
max=20.0,
|
||||||
|
step=0.1,
|
||||||
|
tooltip="Guidance scale (CFG). Not used for Schnell models.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed for reproducibility. Use -1 for random.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"safety_checker",
|
||||||
|
default=True,
|
||||||
|
tooltip="Enable the safety checker.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$prices := {
|
||||||
|
"wavespeed-ai/flux-dev": 0.003,
|
||||||
|
"wavespeed-ai/flux-dev-fp8": 0.003,
|
||||||
|
"wavespeed-ai/flux-schnell": 0.001,
|
||||||
|
"wavespeed-ai/flux-schnell-fp8": 0.001
|
||||||
|
};
|
||||||
|
{"type":"usd", "usd": $lookup($prices, widgets.model)}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
steps: int,
|
||||||
|
guidance_scale: float,
|
||||||
|
seed: int,
|
||||||
|
safety_checker: bool,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
endpoint_name = _MODEL_ENDPOINT[model]
|
||||||
|
is_schnell = model in _SCHNELL_MODELS
|
||||||
|
initial_res = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(
|
||||||
|
path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{endpoint_name}",
|
||||||
|
method="POST",
|
||||||
|
),
|
||||||
|
response_model=TaskCreatedResponse,
|
||||||
|
data=WavespeedTextToImageRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
num_inference_steps=steps,
|
||||||
|
guidance_scale=1.0 if is_schnell else guidance_scale,
|
||||||
|
seed=seed,
|
||||||
|
enable_safety_checker=safety_checker,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if initial_res.code != 200:
|
||||||
|
raise ValueError(f"Task creation failed with code={initial_res.code} and message={initial_res.message}")
|
||||||
|
final_response = await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
|
||||||
|
response_model=TaskResultResponse,
|
||||||
|
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
|
||||||
|
poll_interval=3.0,
|
||||||
|
max_poll_attempts=200,
|
||||||
|
)
|
||||||
|
if final_response.code != 200:
|
||||||
|
raise ValueError(
|
||||||
|
f"Task processing failed with code={final_response.code} and message={final_response.message}"
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
|
||||||
|
|
||||||
|
|
||||||
class WavespeedExtension(ComfyExtension):
|
class WavespeedExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
|
WavespeedTextToImageNode,
|
||||||
WavespeedFlashVSRNode,
|
WavespeedFlashVSRNode,
|
||||||
WavespeedImageUpscaleNode,
|
WavespeedImageUpscaleNode,
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user