[Partner Nodes] feat: add Flux Virtual Try-On and Erase nodes (#14207)

This commit is contained in:
Alexander Piskun 2026-06-01 17:12:12 +03:00 committed by GitHub
parent 462c27fdb2
commit af58c5e674
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 243 additions and 107 deletions

View File

@ -1,71 +1,71 @@
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, confloat, conint
class BFLOutputFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
from pydantic import BaseModel, Field
class BFLFluxExpandImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
top: int = Field(...)
bottom: int = Field(...)
left: int = Field(...)
right: int = Field(...)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
class BFLFluxFillImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
mask: str = Field(
None, description="Base64-encoded string representing the mask of the areas you wish to modify."
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
class BFLFluxEraseRequest(BaseModel):
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
mask: str = Field(
...,
description="A Base64-encoded black/white mask matching the input dimensions; "
"white (255) marks areas to remove, black (0) marks areas to preserve.",
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
dilate_pixels: int = Field(10)
output_format: str = Field("png")
class BFLFluxVTORequest(BaseModel):
prompt: str = Field(
..., description="Natural-language styling instruction. Required field, but may be an empty string."
)
person: str = Field(..., description="A Base64-encoded string representing the person image.")
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
seed: int | None = Field(None)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
# None, description='Blend between the prompt and the image prompt.'
# )
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
class Flux2ProGenerateRequest(BaseModel):
@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel):
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
prompt: str = Field(...)
input_image: str | None = Field(None, description="Image to edit in base64 format")
seed: int | None = Field(None)
guidance: float = Field(...)
steps: int = Field(...)
safety_tolerance: int = Field(2)
output_format: str = Field("png")
aspect_ratio: str | None = Field(None)
prompt_upsampling: bool | None = Field(None)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
None, description='Blend between the prompt and the image prompt.'
)
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
aspect_ratio: str | None = Field(None)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
raw: bool | None = Field(None)
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
image_prompt_strength: float | None = Field(None)
class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
polling_url: str = Field(..., description="URL to poll for the generation result.")
id: str = Field(...)
polling_url: str = Field(...)
cost: float | None = Field(None, description="Price in cents")
@ -145,7 +127,7 @@ class BFLStatus(str, Enum):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
id: str = Field(...)
status: BFLStatus = Field(...)
result: dict[str, Any] | None = Field(None)
progress: float | None = Field(None, ge=0.0, le=1.0)

View File

@ -4,17 +4,20 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl import (
BFLFluxEraseRequest,
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse,
BFLFluxVTORequest,
BFLStatus,
Flux2ProGenerateRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
poll_op,
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
sync_op,
tensor_to_base64_string,
validate_aspect_ratio_string,
validate_image_dimensions,
validate_string,
)
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask] * 3, dim=-1)
return mask
class FluxProUltraImageNode(IO.ComfyNode):
@classmethod
@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxEraseNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxEraseNode",
display_name="Flux Erase Image",
category="image/partner/BFL",
description="Removes the masked object from an image and reconstructs the background. "
"Paint the mask over what you want to erase.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
IO.Int.Input(
"dilate_pixels",
default=10,
min=0,
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
mask = tensor_to_base64_string(convert_mask_to_image(mask))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxEraseRequest(
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxVTONode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxVTONode",
display_name="Flux Virtual Try-On",
category="image/partner/BFL",
description="Virtual try-on: dresses the person in the provided garment.",
inputs=[
IO.Image.Input("person", tooltip="Image of the person to dress."),
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
person: Input.Image,
garment: Input.Image,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxVTORequest(
prompt=prompt,
person=tensor_to_base64_string(person[:, :, :, :3]),
garment=tensor_to_base64_string(garment[:, :, :, :3]),
seed=seed,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension):
FluxKontextMaxImageNode,
FluxProExpandNode,
FluxProFillNode,
FluxEraseNode,
FluxVTONode,
Flux2ProImageNode,
Flux2MaxImageNode,
Flux2ImageNode,