diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py
index 688cb6f81..e3980c254 100644
--- a/comfy/language/transformers_model_management.py
+++ b/comfy/language/transformers_model_management.py
@@ -5,7 +5,6 @@ import inspect
import logging
import operator
import pathlib
-import warnings
from functools import reduce
from typing import Optional, Any, Callable
@@ -24,8 +23,13 @@ from ..component_model.tensor_types import RGBImageBatch
from ..model_downloader import get_or_download_huggingface_repo
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
from ..model_management_types import ModelManageable
-from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block, tensor2pil
+from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
+# tweaks to support florence 2
+_OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']
+
+# should be added if the expectation is that this model emits special tokens
+_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}
class TransformersManagedModel(ModelManageable, LanguageModel):
def __init__(
@@ -46,6 +50,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
self.offload_device = unet_offload_device()
self._config_dict = config_dict
self._on_set_processor(self._processor)
+ self._model_type = ""
if model.device != self.offload_device:
model.to(device=self.offload_device)
@@ -94,7 +99,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
- elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+ elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
else:
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
@@ -139,6 +144,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
processor=processor
)
+ model_managed._model_type = model_type
+
return model_managed
def generate(self, tokens: TOKENS_TYPE = None,
@@ -229,7 +236,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
prev_src_lang = None
# todo: is this redundant consider I'm decoding in the on_finalized_text block?
try:
- outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+ outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=self._model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS, clean_up_tokenization_spaces=False)
finally:
if prev_src_lang is not None:
tokenizer.src_lang = prev_src_lang
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index 3b8810eb8..ca3b28d13 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -451,6 +451,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
'THUDM/chatglm3-6b',
'roborovski/superprompt-v1',
'Qwen/Qwen2-VL-7B-Instruct',
+ 'microsoft/Florence-2-large-ft',
}
KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
diff --git a/comfy/utils.py b/comfy/utils.py
index 0fcb88a71..d3293c5cf 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -70,7 +70,6 @@ else:
logging.debug("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
-
# deprecate PROGRESS_BAR_ENABLED
def _get_progress_bar_enabled():
warnings.warn(
@@ -1230,6 +1229,12 @@ def tensor2pil(t_image: torch.Tensor) -> Image:
return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
+def pil2mask(image):
+ image_np = np.array(image.convert("L")).astype(np.float32) / 255.0
+ mask = torch.from_numpy(image_np)
+ return 1.0 - mask
+
+
def reshape_mask(input_mask, output_shape):
dims = len(output_shape) - 2
diff --git a/comfy_extras/nodes/nodes_florence2.py b/comfy_extras/nodes/nodes_florence2.py
new file mode 100644
index 000000000..93d99a932
--- /dev/null
+++ b/comfy_extras/nodes/nodes_florence2.py
@@ -0,0 +1,171 @@
+from typing import List, Union, Optional
+
+import numpy as np
+import torch
+from PIL import Image, ImageDraw
+from typing_extensions import TypedDict, NotRequired
+
+from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
+from comfy.language.language_types import TOKENS_TYPE_NAME, LanguageModel
+from comfy.language.transformers_model_management import TransformersManagedModel
+from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
+from comfy.utils import pil2mask
+
+TASKS = ['
', '', '', '', '', '', '', '', '', '', '', '', '', '']
+TASKS_TYPE_NAME = "FLORENCE2_TASK"
+FLORENCE2_OUTPUT_TYPE_NAME = "FLORENCE2_OUTPUT"
+
+
+class BoundingBoxResult(TypedDict):
+ bboxes: List[List[float]] # List of [x1, y1, x2, y2] coordinates
+ labels: List[str]
+ scores: Optional[List[float]] # Only present if score mode is used
+
+
+class QuadBoxResult(TypedDict):
+ quad_boxes: List[List[float]] # List of [x1, y1, x2, y2, x3, y3, x4, y4] coordinates
+ labels: List[str]
+
+
+class PolygonResult(TypedDict):
+ polygons: List[List[float]] # List of [x1, y1, x2, y2, ...] coordinates
+ labels: List[str]
+
+
+class BBoxesAndPolygonsResult(TypedDict):
+ bboxes: List[List[float]]
+ bboxes_labels: List[str]
+ polygons: List[List[float]]
+ polygons_labels: List[str]
+
+
+PostProcessResult = TypedDict('PostProcessResult', {
+ '': NotRequired[Union[str, QuadBoxResult]], # pure_text or ocr
+ '': NotRequired[QuadBoxResult], # ocr
+ '': NotRequired[str], # pure_text
+ '': NotRequired[str], # pure_text
+ '': NotRequired[str], # pure_text
+ '': NotRequired[BoundingBoxResult], # description_with_bboxes
+ '': NotRequired[BoundingBoxResult], # description_with_bboxes
+ '': NotRequired[BoundingBoxResult], # phrase_grounding
+ '': NotRequired[PolygonResult], # polygons
+ '': NotRequired[PolygonResult], # polygons
+ '': NotRequired[BBoxesAndPolygonsResult], # description_with_bboxes_or_polygons
+ '': NotRequired[str], # pure_text
+ '': NotRequired[str], # pure_text
+ '': NotRequired[str], # pure_text
+ '': NotRequired[BoundingBoxResult] # bboxes
+})
+
+
+def draw_polygons(image: Image, prediction: PolygonResult) -> Image:
+ """
+ Draws segmentation masks with polygons on an image.
+
+ Parameters:
+ - image_path: Path to the image file.
+ - prediction: Dictionary containing 'polygons' and 'labels' keys.
+ 'polygons' is a list of lists, each containing vertices of a polygon.
+ 'labels' is a list of labels corresponding to each polygon.
+ - fill_mask: Boolean indicating whether to fill the polygons with color.
+ """
+ # Load the image
+
+ draw = ImageDraw.Draw(image)
+
+ # Set up scale factor if needed (use 1 if not scaling)
+ scale = 1
+
+ # Iterate over polygons and labels
+ for polygons, label in zip(prediction['polygons'], prediction['labels']):
+ for _polygon in polygons:
+ _polygon = np.array(_polygon).reshape(-1, 2)
+ if len(_polygon) < 3:
+ print('Invalid polygon:', _polygon)
+ continue
+
+ _polygon = (_polygon * scale).reshape(-1).tolist()
+
+ # Draw the polygon
+ draw.polygon(_polygon, fill='white')
+ return image
+
+
+class Florence2TaskTokenize(CustomNode):
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypes:
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "prompt": ("STRING", {"default": "", "multiline": True}),
+ "task": (TASKS, {"default": TASKS[0]})
+ },
+ "optional": {
+ "images": ("IMAGE", {}),
+ }
+ }
+
+ CATEGORY = "language"
+ RETURN_TYPES = (TOKENS_TYPE_NAME, TASKS_TYPE_NAME)
+ RETURN_NAMES = ("tokens",)
+ FUNCTION = "execute"
+
+ def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, task: str = "") -> ValidatedNodeResult:
+ return model.tokenize(prompt, images, task + prompt),
+
+
+class Florence2PostProcess(CustomNode):
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypes:
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "generated_text": ("STRING", {"forceInput": True}),
+ "task": (TASKS, {"default": TASKS[0]})
+ },
+ "optional": {
+ "images": ("IMAGE", {}),
+ }
+ }
+
+ CATEGORY = "language"
+ RETURN_TYPES = (FLORENCE2_OUTPUT_TYPE_NAME,)
+ RETURN_NAMES = ("florence2 output",)
+ FUNCTION = "execute"
+
+ def execute(self, model: TransformersManagedModel, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]:
+ assert hasattr(model.processor, "post_process_generation")
+ return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])),
+
+
+class Florence2OutputToPolygon(CustomNode):
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypes:
+ return {
+ "required": {
+ "florence2_output": (FLORENCE2_OUTPUT_TYPE_NAME, {}),
+ },
+ "optional": {
+ "images": ("IMAGE", {}),
+ }
+ }
+
+ CATEGORY = "language"
+ RETURN_TYPES = ("MASK",)
+ FUNCTION = "execute"
+
+ def execute(self, florence2_output: PostProcessResult, images: RGBImageBatch = None) -> tuple[MaskBatch]:
+ image = Image.new('RGB', (images.shape[-2], images.shape[-3]), color='black')
+ for prediction in ('', '', ''):
+ if prediction in florence2_output:
+ image = draw_polygons(image, florence2_output[prediction])
+ return pil2mask(image),
+
+
+NODE_CLASS_MAPPINGS = {}
+for cls in (
+ Florence2PostProcess,
+ Florence2TaskTokenize,
+ Florence2OutputToPolygon
+):
+ NODE_CLASS_MAPPINGS[cls.__name__] = cls
\ No newline at end of file