mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-02 12:27:59 +08:00
[Partner Nodes] feat: add Flux Virtual Try-On and Erase nodes (#14207)
This commit is contained in:
parent
462c27fdb2
commit
af58c5e674
@ -1,71 +1,71 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, confloat, conint
|
||||
|
||||
|
||||
class BFLOutputFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BFLFluxExpandImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
|
||||
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
|
||||
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
|
||||
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
top: int = Field(...)
|
||||
bottom: int = Field(...)
|
||||
left: int = Field(...)
|
||||
right: int = Field(...)
|
||||
steps: int = Field(...)
|
||||
guidance: float = Field(...)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
|
||||
|
||||
|
||||
class BFLFluxFillImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
steps: int = Field(...)
|
||||
guidance: float = Field(...)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image: str = Field(
|
||||
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
mask: str = Field(
|
||||
None, description="Base64-encoded string representing the mask of the areas you wish to modify."
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
|
||||
|
||||
class BFLFluxEraseRequest(BaseModel):
|
||||
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
|
||||
mask: str = Field(
|
||||
...,
|
||||
description="A Base64-encoded black/white mask matching the input dimensions; "
|
||||
"white (255) marks areas to remove, black (0) marks areas to preserve.",
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||
dilate_pixels: int = Field(10)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxVTORequest(BaseModel):
|
||||
prompt: str = Field(
|
||||
..., description="Natural-language styling instruction. Required field, but may be an empty string."
|
||||
)
|
||||
person: str = Field(..., description="A Base64-encoded string representing the person image.")
|
||||
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
|
||||
seed: int | None = Field(None)
|
||||
safety_tolerance: int = Field(5)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
|
||||
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
# None, description='Blend between the prompt and the image prompt.'
|
||||
# )
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
width: int = Field(1024, description="Must be a multiple of 32.")
|
||||
height: int = Field(768, description="Must be a multiple of 32.")
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
|
||||
|
||||
|
||||
class Flux2ProGenerateRequest(BaseModel):
|
||||
@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel):
|
||||
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
safety_tolerance: int | None = Field(
|
||||
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
|
||||
)
|
||||
output_format: str | None = Field(
|
||||
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
|
||||
)
|
||||
safety_tolerance: int = Field(5)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxKontextProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
|
||||
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
|
||||
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
|
||||
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
prompt: str = Field(...)
|
||||
input_image: str | None = Field(None, description="Image to edit in base64 format")
|
||||
seed: int | None = Field(None)
|
||||
guidance: float = Field(...)
|
||||
steps: int = Field(...)
|
||||
safety_tolerance: int = Field(2)
|
||||
output_format: str = Field("png")
|
||||
aspect_ratio: str | None = Field(None)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
|
||||
|
||||
class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
None, description='Blend between the prompt and the image prompt.'
|
||||
)
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
aspect_ratio: str | None = Field(None)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
raw: bool | None = Field(None)
|
||||
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
|
||||
image_prompt_strength: float | None = Field(None)
|
||||
|
||||
|
||||
class BFLFluxProGenerateResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
polling_url: str = Field(..., description="URL to poll for the generation result.")
|
||||
id: str = Field(...)
|
||||
polling_url: str = Field(...)
|
||||
cost: float | None = Field(None, description="Price in cents")
|
||||
|
||||
|
||||
@ -145,7 +127,7 @@ class BFLStatus(str, Enum):
|
||||
|
||||
|
||||
class BFLFluxStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
|
||||
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
|
||||
id: str = Field(...)
|
||||
status: BFLStatus = Field(...)
|
||||
result: dict[str, Any] | None = Field(None)
|
||||
progress: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
@ -4,17 +4,20 @@ from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bfl import (
|
||||
BFLFluxEraseRequest,
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxKontextProGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxStatusResponse,
|
||||
BFLFluxVTORequest,
|
||||
BFLStatus,
|
||||
Flux2ProGenerateRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
convert_mask_to_image,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
validate_aspect_ratio_string,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: Input.Image):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = torch.cat([mask] * 3, dim=-1)
|
||||
return mask
|
||||
|
||||
|
||||
class FluxProUltraImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxEraseNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxEraseNode",
|
||||
display_name="Flux Erase Image",
|
||||
category="image/partner/BFL",
|
||||
description="Removes the masked object from an image and reconstructs the background. "
|
||||
"Paint the mask over what you want to erase.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
|
||||
IO.Int.Input(
|
||||
"dilate_pixels",
|
||||
default=10,
|
||||
min=0,
|
||||
max=25,
|
||||
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
mask: Input.Image,
|
||||
dilate_pixels: int = 10,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=256, min_height=256)
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = tensor_to_base64_string(convert_mask_to_image(mask))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxEraseRequest(
|
||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||
mask=mask,
|
||||
dilate_pixels=dilate_pixels,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxVTONode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxVTONode",
|
||||
display_name="Flux Virtual Try-On",
|
||||
category="image/partner/BFL",
|
||||
description="Virtual try-on: dresses the person in the provided garment.",
|
||||
inputs=[
|
||||
IO.Image.Input("person", tooltip="Image of the person to dress."),
|
||||
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
person: Input.Image,
|
||||
garment: Input.Image,
|
||||
prompt: str = "",
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxVTORequest(
|
||||
prompt=prompt,
|
||||
person=tensor_to_base64_string(person[:, :, :, :3]),
|
||||
garment=tensor_to_base64_string(garment[:, :, :, :3]),
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class Flux2ProImageNode(IO.ComfyNode):
|
||||
|
||||
NODE_ID = "Flux2ProImageNode"
|
||||
@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension):
|
||||
FluxKontextMaxImageNode,
|
||||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
FluxEraseNode,
|
||||
FluxVTONode,
|
||||
Flux2ProImageNode,
|
||||
Flux2MaxImageNode,
|
||||
Flux2ImageNode,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user