from typing import List, Union, Optional
import numpy as np
import torch
from PIL import Image, ImageDraw
from typing_extensions import TypedDict, NotRequired
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
from comfy.language.language_types import TOKENS_TYPE_NAME, LanguageModel
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import pil2mask
# Task prompt tokens understood by the Florence-2 processor; each one selects a different
# post-processing mode (see PostProcessResult below for the task-to-output mapping).
TASKS = [
    '<OCR>',
    '<OCR_WITH_REGION>',
    '<CAPTION>',
    '<DETAILED_CAPTION>',
    '<MORE_DETAILED_CAPTION>',
    '<OD>',
    '<DENSE_REGION_CAPTION>',
    '<CAPTION_TO_PHRASE_GROUNDING>',
    '<REFERRING_EXPRESSION_SEGMENTATION>',
    '<REGION_TO_SEGMENTATION>',
    '<OPEN_VOCABULARY_DETECTION>',
    '<REGION_TO_CATEGORY>',
    '<REGION_TO_DESCRIPTION>',
    '<REGION_TO_OCR>',
    '<REGION_PROPOSAL>',
]
TASKS_TYPE_NAME = "FLORENCE2_TASK"
FLORENCE2_OUTPUT_TYPE_NAME = "FLORENCE2_OUTPUT"
class BoundingBoxResult(TypedDict):
bboxes: List[List[float]] # List of [x1, y1, x2, y2] coordinates
labels: List[str]
    scores: NotRequired[List[float]]  # Only present when the post-processor emits scores
class QuadBoxResult(TypedDict):
quad_boxes: List[List[float]] # List of [x1, y1, x2, y2, x3, y3, x4, y4] coordinates
labels: List[str]
class PolygonResult(TypedDict):
    polygons: List[List[List[float]]]  # One list of polygons per label; each polygon is a flat [x1, y1, x2, y2, ...] list
labels: List[str]
class BBoxesAndPolygonsResult(TypedDict):
bboxes: List[List[float]]
bboxes_labels: List[str]
polygons: List[List[float]]
polygons_labels: List[str]
PostProcessResult = TypedDict('PostProcessResult', {
    '<OCR>': NotRequired[Union[str, QuadBoxResult]],  # pure_text or ocr
    '<OCR_WITH_REGION>': NotRequired[QuadBoxResult],  # ocr
    '<CAPTION>': NotRequired[str],  # pure_text
    '<DETAILED_CAPTION>': NotRequired[str],  # pure_text
    '<MORE_DETAILED_CAPTION>': NotRequired[str],  # pure_text
    '<OD>': NotRequired[BoundingBoxResult],  # description_with_bboxes
    '<DENSE_REGION_CAPTION>': NotRequired[BoundingBoxResult],  # description_with_bboxes
    '<CAPTION_TO_PHRASE_GROUNDING>': NotRequired[BoundingBoxResult],  # phrase_grounding
    '<REFERRING_EXPRESSION_SEGMENTATION>': NotRequired[PolygonResult],  # polygons
    '<REGION_TO_SEGMENTATION>': NotRequired[PolygonResult],  # polygons
    '<OPEN_VOCABULARY_DETECTION>': NotRequired[BBoxesAndPolygonsResult],  # description_with_bboxes_or_polygons
    '<REGION_TO_CATEGORY>': NotRequired[str],  # pure_text
    '<REGION_TO_DESCRIPTION>': NotRequired[str],  # pure_text
    '<REGION_TO_OCR>': NotRequired[str],  # pure_text
    '<REGION_PROPOSAL>': NotRequired[BoundingBoxResult]  # bboxes
})
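# Illustrative shape (values are made up) of what the Florence-2 processor's
# post_process_generation returns for an object-detection task, which is the structure
# PostProcessResult above models; pure-text tasks such as '<CAPTION>' map the task token
# to a plain string instead:
#
#     {'<OD>': {'bboxes': [[34.2, 16.0, 597.4, 372.8]], 'labels': ['car']}}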
def draw_polygons(image: Image.Image, prediction: PolygonResult) -> Image.Image:
    """
    Draws segmentation polygons onto an image as filled white regions.
    Parameters:
    - image: PIL image to draw onto; it is modified in place and also returned.
    - prediction: Dictionary containing 'polygons' and 'labels' keys.
      'polygons' holds one entry per label, each entry a list of polygons given as
      flat [x1, y1, x2, y2, ...] coordinate lists.
      'labels' is a list of labels corresponding to each entry.
    """
    draw = ImageDraw.Draw(image)
# Set up scale factor if needed (use 1 if not scaling)
scale = 1
# Iterate over polygons and labels
for polygons, label in zip(prediction['polygons'], prediction['labels']):
for _polygon in polygons:
_polygon = np.array(_polygon).reshape(-1, 2)
if len(_polygon) < 3:
print('Invalid polygon:', _polygon)
continue
_polygon = (_polygon * scale).reshape(-1).tolist()
# Draw the polygon
draw.polygon(_polygon, fill='white')
return image
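# Minimal usage sketch for draw_polygons with hand-written, hypothetical values shaped like
# the PolygonResult produced by the '<REFERRING_EXPRESSION_SEGMENTATION>' task:
#
#     canvas = Image.new('RGB', (640, 480), color='black')
#     result: PolygonResult = {'polygons': [[[100.0, 100.0, 300.0, 120.0, 220.0, 340.0]]], 'labels': ['a dog']}
#     mask_image = draw_polygons(canvas, result)
#
# Each label owns a list of polygons; every polygon is reshaped to (N, 2) vertices and
# filled white, so the returned image can be converted into a MASK with pil2mask.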
class Florence2TaskTokenize(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
"task": (TASKS, {"default": TASKS[0]})
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = (TOKENS_TYPE_NAME, TASKS_TYPE_NAME)
    RETURN_NAMES = ("tokens", "task")
    FUNCTION = "execute"
    def execute(self, model: LanguageModel, prompt: str, images: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, task: str = "", ) -> ValidatedNodeResult:
        # Florence-2 prompts are the task token followed by the optional text prompt;
        # the task is also passed through as a second output for downstream post-processing
        return model.tokenize(prompt, images, task + prompt), task
class Florence2PostProcess(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"generated_text": ("STRING", {"forceInput": True}),
"task": (TASKS, {"default": TASKS[0]})
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = (FLORENCE2_OUTPUT_TYPE_NAME,)
RETURN_NAMES = ("florence2 output",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]:
assert hasattr(model.processor, "post_process_generation")
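        # ComfyUI image batches are NHWC, so shape[-2] is the width and shape[-3] is the height expected by image_size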
return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])),
class Florence2OutputToMask(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"florence2_output": (FLORENCE2_OUTPUT_TYPE_NAME, {}),
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = ("MASK",)
FUNCTION = "execute"
    def execute(self, florence2_output: PostProcessResult, images: RGBImageBatch = None) -> tuple[MaskBatch]:
        # ComfyUI image batches are NHWC, so shape[-2] is the width and shape[-3] is the height
        image = Image.new('RGB', (images.shape[-2], images.shape[-3]), color='black')
        # Only the segmentation tasks return PolygonResult payloads that can be rasterized into a mask
        for prediction in ('<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>'):
            if prediction in florence2_output:
                image = draw_polygons(image, florence2_output[prediction])
        return pil2mask(image),
NODE_CLASS_MAPPINGS = {}
for cls in (
Florence2PostProcess,
Florence2TaskTokenize,
Florence2OutputToMask
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls
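# Sketch of how these nodes are meant to be chained in a workflow; the generation step in the
# middle is handled by a separate language-model sampling node and is only described here:
#
#     tokens, task = Florence2TaskTokenize().execute(model, prompt="a dog", images=images, task='<REFERRING_EXPRESSION_SEGMENTATION>')
#     # ... run the language model on `tokens` to obtain `generated_text` ...
#     florence2_output, = Florence2PostProcess().execute(model, generated_text=generated_text, task=task, images=images)
#     mask, = Florence2OutputToMask().execute(florence2_output, images=images)
#
# The variables `model`, `images` and `generated_text` are placeholders for the values the
# surrounding workflow would supply.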