add Flux2MaxImage API Node (#11420)

This commit is contained in:
Alexander Piskun 2025-12-19 20:02:49 +02:00 committed by GitHub
parent 894802b0f9
commit 5b4d0664c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,10 +1,8 @@
from inspect import cleandoc
import torch import torch
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl_api import ( from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest, BFLFluxExpandImageRequest,
BFLFluxFillImageRequest, BFLFluxFillImageRequest,
@ -28,7 +26,7 @@ from comfy_api_nodes.util import (
) )
def convert_mask_to_image(mask: torch.Tensor): def convert_mask_to_image(mask: Input.Image):
""" """
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
""" """
@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor):
class FluxProUltraImageNode(IO.ComfyNode): class FluxProUltraImageNode(IO.ComfyNode):
"""
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
node_id="FluxProUltraImageNode", node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image", display_name="Flux 1.1 [pro] Ultra Image",
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -117,7 +112,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt_upsampling: bool = False, prompt_upsampling: bool = False,
raw: bool = False, raw: bool = False,
seed: int = 0, seed: int = 0,
image_prompt: torch.Tensor | None = None, image_prompt: Input.Image | None = None,
image_prompt_strength: float = 0.1, image_prompt_strength: float = 0.1,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if image_prompt is None: if image_prompt is None:
@ -155,9 +150,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
class FluxKontextProImageNode(IO.ComfyNode): class FluxKontextProImageNode(IO.ComfyNode):
"""
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -165,7 +157,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
node_id=cls.NODE_ID, node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME, display_name=cls.DISPLAY_NAME,
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -231,7 +223,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
guidance: float, guidance: float,
steps: int, steps: int,
input_image: torch.Tensor | None = None, input_image: Input.Image | None = None,
seed=0, seed=0,
prompt_upsampling=False, prompt_upsampling=False,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
@ -271,20 +263,14 @@ class FluxKontextProImageNode(IO.ComfyNode):
class FluxKontextMaxImageNode(FluxKontextProImageNode): class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "") DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio."
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
NODE_ID = "FluxKontextMaxImageNode" NODE_ID = "FluxKontextMaxImageNode"
DISPLAY_NAME = "Flux.1 Kontext [max] Image" DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProExpandNode(IO.ComfyNode): class FluxProExpandNode(IO.ComfyNode):
"""
Outpaints image based on prompt.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -292,7 +278,7 @@ class FluxProExpandNode(IO.ComfyNode):
node_id="FluxProExpandNode", node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image", display_name="Flux.1 Expand Image",
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Outpaints image based on prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.String.Input( IO.String.Input(
@ -371,7 +357,7 @@ class FluxProExpandNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
image: torch.Tensor, image: Input.Image,
prompt: str, prompt: str,
prompt_upsampling: bool, prompt_upsampling: bool,
top: int, top: int,
@ -418,9 +404,6 @@ class FluxProExpandNode(IO.ComfyNode):
class FluxProFillNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode):
"""
Inpaints image based on mask and prompt.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
@ -428,7 +411,7 @@ class FluxProFillNode(IO.ComfyNode):
node_id="FluxProFillNode", node_id="FluxProFillNode",
display_name="Flux.1 Fill Image", display_name="Flux.1 Fill Image",
category="api node/image/BFL", category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""), description="Inpaints image based on mask and prompt.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Mask.Input("mask"), IO.Mask.Input("mask"),
@ -480,8 +463,8 @@ class FluxProFillNode(IO.ComfyNode):
@classmethod @classmethod
async def execute( async def execute(
cls, cls,
image: torch.Tensor, image: Input.Image,
mask: torch.Tensor, mask: Input.Image,
prompt: str, prompt: str,
prompt_upsampling: bool, prompt_upsampling: bool,
steps: int, steps: int,
@ -525,11 +508,15 @@ class FluxProFillNode(IO.ComfyNode):
class Flux2ProImageNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="Flux2ProImageNode", node_id=cls.NODE_ID,
display_name="Flux.2 [pro] Image", display_name=cls.DISPLAY_NAME,
category="api node/image/BFL", category="api node/image/BFL",
description="Generates images synchronously based on prompt and resolution.", description="Generates images synchronously based on prompt and resolution.",
inputs=[ inputs=[
@ -563,12 +550,11 @@ class Flux2ProImageNode(IO.ComfyNode):
), ),
IO.Boolean.Input( IO.Boolean.Input(
"prompt_upsampling", "prompt_upsampling",
default=False, default=True,
tooltip="Whether to perform upsampling on the prompt. " tooltip="Whether to perform upsampling on the prompt. "
"If active, automatically modifies the prompt for more creative generation, " "If active, automatically modifies the prompt for more creative generation.",
"but results are nondeterministic (same seed will not produce exactly the same result).",
), ),
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."),
], ],
outputs=[IO.Image.Output()], outputs=[IO.Image.Output()],
hidden=[ hidden=[
@ -587,7 +573,7 @@ class Flux2ProImageNode(IO.ComfyNode):
height: int, height: int,
seed: int, seed: int,
prompt_upsampling: bool, prompt_upsampling: bool,
images: torch.Tensor | None = None, images: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
reference_images = {} reference_images = {}
if images is not None: if images is not None:
@ -598,7 +584,7 @@ class Flux2ProImageNode(IO.ComfyNode):
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), ApiEndpoint(path=cls.API_ENDPOINT, method="POST"),
response_model=BFLFluxProGenerateResponse, response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest( data=Flux2ProGenerateRequest(
prompt=prompt, prompt=prompt,
@ -632,6 +618,13 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
class BFLExtension(ComfyExtension): class BFLExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -642,6 +635,7 @@ class BFLExtension(ComfyExtension):
FluxProExpandNode, FluxProExpandNode,
FluxProFillNode, FluxProFillNode,
Flux2ProImageNode, Flux2ProImageNode,
Flux2MaxImageNode,
] ]