[Partner Nodes] feat: add Flux Virtual Try-On and Erase nodes (#14207)

This commit is contained in:
Alexander Piskun 2026-06-01 17:12:12 +03:00 committed by GitHub
parent 462c27fdb2
commit af58c5e674
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 243 additions and 107 deletions

View File

@ -1,71 +1,71 @@
from enum import Enum from enum import Enum
from typing import Any, Dict, Optional from typing import Any
from pydantic import BaseModel, Field, confloat, conint from pydantic import BaseModel, Field
class BFLOutputFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
class BFLFluxExpandImageRequest(BaseModel): class BFLFluxExpandImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
) top: int = Field(...)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') bottom: int = Field(...)
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image') left: int = Field(...)
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image') right: int = Field(...)
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image') steps: int = Field(...)
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image') guidance: float = Field(...)
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') safety_tolerance: int = Field(6)
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') output_format: str = Field("png")
safety_tolerance: Optional[conint(ge=0, le=6)] = Field( image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
class BFLFluxFillImageRequest(BaseModel): class BFLFluxFillImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
) )
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') mask: str = Field(
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') None, description="Base64-encoded string representing the mask of the areas you wish to modify."
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
) )
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
class BFLFluxEraseRequest(BaseModel):
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
mask: str = Field(
...,
description="A Base64-encoded black/white mask matching the input dimensions; "
"white (255) marks areas to remove, black (0) marks areas to preserve.",
) )
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.') dilate_pixels: int = Field(10)
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') output_format: str = Field("png")
class BFLFluxVTORequest(BaseModel):
prompt: str = Field(
..., description="Natural-language styling instruction. Required field, but may be an empty string."
)
person: str = Field(..., description="A Base64-encoded string representing the person image.")
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
seed: int | None = Field(None)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxProGenerateRequest(BaseModel): class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
) width: int = Field(1024, description="Must be a multiple of 32.")
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') height: int = Field(768, description="Must be a multiple of 32.")
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.') safety_tolerance: int = Field(6)
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.') output_format: str = Field("png")
safety_tolerance: Optional[conint(ge=0, le=6)] = Field( image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
# None, description='Blend between the prompt and the image prompt.'
# )
class Flux2ProGenerateRequest(BaseModel): class Flux2ProGenerateRequest(BaseModel):
@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel):
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field( safety_tolerance: int = Field(5)
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 output_format: str = Field("png")
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
class BFLFluxKontextProGenerateRequest(BaseModel): class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.') prompt: str = Field(...)
input_image: Optional[str] = Field(None, description='Image to edit in base64 format') input_image: str | None = Field(None, description="Image to edit in base64 format")
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') seed: int | None = Field(None)
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process') guidance: float = Field(...)
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process') steps: int = Field(...)
safety_tolerance: Optional[conint(ge=0, le=2)] = Field( safety_tolerance: int = Field(2)
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.' output_format: str = Field("png")
) aspect_ratio: str | None = Field(None)
output_format: Optional[BFLOutputFormat] = Field( prompt_upsampling: bool | None = Field(None)
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
class BFLFluxProUltraGenerateRequest(BaseModel): class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.') prompt: str = Field(...)
prompt_upsampling: Optional[bool] = Field( prompt_upsampling: bool | None = Field(None)
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' seed: int | None = Field(None)
) aspect_ratio: str | None = Field(None)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.') safety_tolerance: int = Field(6)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') output_format: str = Field("png")
safety_tolerance: Optional[conint(ge=0, le=6)] = Field( raw: bool | None = Field(None)
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
) image_prompt_strength: float | None = Field(None)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
None, description='Blend between the prompt and the image prompt.'
)
class BFLFluxProGenerateResponse(BaseModel): class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.") id: str = Field(...)
polling_url: str = Field(..., description="URL to poll for the generation result.") polling_url: str = Field(...)
cost: float | None = Field(None, description="Price in cents") cost: float | None = Field(None, description="Price in cents")
@ -145,7 +127,7 @@ class BFLStatus(str, Enum):
class BFLFluxStatusResponse(BaseModel): class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.") id: str = Field(...)
status: BFLStatus = Field(..., description="The status of the task.") status: BFLStatus = Field(...)
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") result: dict[str, Any] | None = Field(None)
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) progress: float | None = Field(None, ge=0.0, le=1.0)

View File

@ -4,17 +4,20 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl import ( from comfy_api_nodes.apis.bfl import (
BFLFluxEraseRequest,
BFLFluxExpandImageRequest, BFLFluxExpandImageRequest,
BFLFluxFillImageRequest, BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest, BFLFluxKontextProGenerateRequest,
BFLFluxProGenerateResponse, BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest, BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse, BFLFluxStatusResponse,
BFLFluxVTORequest,
BFLStatus, BFLStatus,
Flux2ProGenerateRequest, Flux2ProGenerateRequest,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor, download_url_to_image_tensor,
get_number_of_images, get_number_of_images,
poll_op, poll_op,
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
validate_aspect_ratio_string, validate_aspect_ratio_string,
validate_image_dimensions,
validate_string, validate_string,
) )
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask] * 3, dim=-1)
return mask
class FluxProUltraImageNode(IO.ComfyNode): class FluxProUltraImageNode(IO.ComfyNode):
@classmethod @classmethod
@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxEraseNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxEraseNode",
display_name="Flux Erase Image",
category="image/partner/BFL",
description="Removes the masked object from an image and reconstructs the background. "
"Paint the mask over what you want to erase.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
IO.Int.Input(
"dilate_pixels",
default=10,
min=0,
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
mask = tensor_to_base64_string(convert_mask_to_image(mask))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxEraseRequest(
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxVTONode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxVTONode",
display_name="Flux Virtual Try-On",
category="image/partner/BFL",
description="Virtual try-on: dresses the person in the provided garment.",
inputs=[
IO.Image.Input("person", tooltip="Image of the person to dress."),
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
person: Input.Image,
garment: Input.Image,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxVTORequest(
prompt=prompt,
person=tensor_to_base64_string(person[:, :, :, :3]),
garment=tensor_to_base64_string(garment[:, :, :, :3]),
seed=seed,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode" NODE_ID = "Flux2ProImageNode"
@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension):
FluxKontextMaxImageNode, FluxKontextMaxImageNode,
FluxProExpandNode, FluxProExpandNode,
FluxProFillNode, FluxProFillNode,
FluxEraseNode,
FluxVTONode,
Flux2ProImageNode, Flux2ProImageNode,
Flux2MaxImageNode, Flux2MaxImageNode,
Flux2ImageNode, Flux2ImageNode,