from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.wavespeed import ( FlashVSRRequest, TaskCreatedResponse, TaskResultResponse, SeedVR2ImageRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_video_output, poll_op, sync_op, upload_video_to_comfyapi, validate_container_format_is_mp4, validate_video_duration, upload_images_to_comfyapi, get_number_of_images, download_url_to_image_tensor, ) class WavespeedFlashVSRNode(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="WavespeedFlashVSRNode", display_name="FlashVSR Video Upscale", category="api node/video/WaveSpeed", description="Fast, high-quality video upscaler that " "boosts resolution and restores clarity for low-resolution or blurry footage.", inputs=[ IO.Video.Input("video"), IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]), ], outputs=[ IO.Video.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, price_badge=IO.PriceBadge( depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]), expr=""" ( $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032}; { "type":"usd", "usd": $lookup($price_for_1sec, widgets.target_resolution), "format":{"suffix": "/second", "approximate": true} } ) """, ), ) @classmethod async def execute( cls, video: Input.Video, target_resolution: str, ) -> IO.NodeOutput: validate_container_format_is_mp4(video) validate_video_duration(video, min_duration=5, max_duration=60 * 10) initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"), response_model=TaskCreatedResponse, data=FlashVSRRequest( target_resolution=target_resolution.lower(), video=await upload_video_to_comfyapi(cls, video), duration=video.get_duration(), ), ) if initial_res.code != 200: raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), response_model=TaskResultResponse, status_extractor=lambda x: "failed" if x.data is None else x.data.status, poll_interval=10.0, max_poll_attempts=480, ) if final_response.code != 200: raise ValueError( f"Task processing failed with code={final_response.code} and message={final_response.message}" ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0])) class WavespeedImageUpscaleNode(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="WavespeedImageUpscaleNode", display_name="WaveSpeed Image Upscale", category="api node/image/WaveSpeed", description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", inputs=[ IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), IO.Image.Input("image"), IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]), ], outputs=[ IO.Image.Output(), ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, price_badge=IO.PriceBadge( depends_on=IO.PriceBadgeDepends(widgets=["model"]), expr=""" ( $prices := {"seedvr2": 0.01, "ultimate": 0.06}; {"type":"usd", "usd": $lookup($prices, widgets.model)} ) """, ), ) @classmethod async def execute( cls, model: str, image: Input.Image, target_resolution: str, ) -> IO.NodeOutput: if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") if model == "SeedVR2": model_path = "seedvr2/image" else: model_path = "ultimate-image-upscaler" initial_res = await sync_op( cls, ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"), response_model=TaskCreatedResponse, data=SeedVR2ImageRequest( target_resolution=target_resolution.lower(), image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], ), ) if initial_res.code != 200: raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), response_model=TaskResultResponse, status_extractor=lambda x: "failed" if x.data is None else x.data.status, poll_interval=10.0, max_poll_attempts=480, ) if final_response.code != 200: raise ValueError( f"Task processing failed with code={final_response.code} and message={final_response.message}" ) return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) class WavespeedExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ WavespeedFlashVSRNode, WavespeedImageUpscaleNode, ] async def comfy_entrypoint() -> WavespeedExtension: return WavespeedExtension()