diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index 688cb6f81..e3980c254 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -5,7 +5,6 @@ import inspect import logging import operator import pathlib -import warnings from functools import reduce from typing import Optional, Any, Callable @@ -24,8 +23,13 @@ from ..component_model.tensor_types import RGBImageBatch from ..model_downloader import get_or_download_huggingface_repo from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu from ..model_management_types import ModelManageable -from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block, tensor2pil +from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block +# tweaks to support florence 2 +_OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2'] + +# should be added if the expectation is that this model emits special tokens +_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'} class TransformersManagedModel(ModelManageable, LanguageModel): def __init__( @@ -46,6 +50,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel): self.offload_device = unet_offload_device() self._config_dict = config_dict self._on_set_processor(self._processor) + self._model_type = "" if model.device != self.offload_device: model.to(device=self.offload_device) @@ -94,7 +99,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel): model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props) elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props) - elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props) else: model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props) @@ -139,6 +144,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel): processor=processor ) + model_managed._model_type = model_type + return model_managed def generate(self, tokens: TOKENS_TYPE = None, @@ -229,7 +236,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel): prev_src_lang = None # todo: is this redundant consider I'm decoding in the on_finalized_text block? try: - outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=self._model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS, clean_up_tokenization_spaces=False) finally: if prev_src_lang is not None: tokenizer.src_lang = prev_src_lang diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 3b8810eb8..ca3b28d13 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -451,6 +451,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = { 'THUDM/chatglm3-6b', 'roborovski/superprompt-v1', 'Qwen/Qwen2-VL-7B-Instruct', + 'microsoft/Florence-2-large-ft', } KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/utils.py b/comfy/utils.py index 0fcb88a71..d3293c5cf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -70,7 +70,6 @@ else: logging.debug("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") - # deprecate PROGRESS_BAR_ENABLED def _get_progress_bar_enabled(): warnings.warn( @@ -1230,6 +1229,12 @@ def tensor2pil(t_image: torch.Tensor) -> Image: return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) +def pil2mask(image): + image_np = np.array(image.convert("L")).astype(np.float32) / 255.0 + mask = torch.from_numpy(image_np) + return 1.0 - mask + + def reshape_mask(input_mask, output_shape): dims = len(output_shape) - 2 diff --git a/comfy_extras/nodes/nodes_florence2.py b/comfy_extras/nodes/nodes_florence2.py new file mode 100644 index 000000000..93d99a932 --- /dev/null +++ b/comfy_extras/nodes/nodes_florence2.py @@ -0,0 +1,171 @@ +from typing import List, Union, Optional + +import numpy as np +import torch +from PIL import Image, ImageDraw +from typing_extensions import TypedDict, NotRequired + +from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch +from comfy.language.language_types import TOKENS_TYPE_NAME, LanguageModel +from comfy.language.transformers_model_management import TransformersManagedModel +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult +from comfy.utils import pil2mask + +TASKS = ['', '', '', '', '', '', '', '', '', '', '', '', '', ''] +TASKS_TYPE_NAME = "FLORENCE2_TASK" +FLORENCE2_OUTPUT_TYPE_NAME = "FLORENCE2_OUTPUT" + + +class BoundingBoxResult(TypedDict): + bboxes: List[List[float]] # List of [x1, y1, x2, y2] coordinates + labels: List[str] + scores: Optional[List[float]] # Only present if score mode is used + + +class QuadBoxResult(TypedDict): + quad_boxes: List[List[float]] # List of [x1, y1, x2, y2, x3, y3, x4, y4] coordinates + labels: List[str] + + +class PolygonResult(TypedDict): + polygons: List[List[float]] # List of [x1, y1, x2, y2, ...] coordinates + labels: List[str] + + +class BBoxesAndPolygonsResult(TypedDict): + bboxes: List[List[float]] + bboxes_labels: List[str] + polygons: List[List[float]] + polygons_labels: List[str] + + +PostProcessResult = TypedDict('PostProcessResult', { + '': NotRequired[Union[str, QuadBoxResult]], # pure_text or ocr + '': NotRequired[QuadBoxResult], # ocr + '': NotRequired[str], # pure_text + '': NotRequired[str], # pure_text + '': NotRequired[str], # pure_text + '': NotRequired[BoundingBoxResult], # description_with_bboxes + '': NotRequired[BoundingBoxResult], # description_with_bboxes + '': NotRequired[BoundingBoxResult], # phrase_grounding + '': NotRequired[PolygonResult], # polygons + '': NotRequired[PolygonResult], # polygons + '': NotRequired[BBoxesAndPolygonsResult], # description_with_bboxes_or_polygons + '': NotRequired[str], # pure_text + '': NotRequired[str], # pure_text + '': NotRequired[str], # pure_text + '': NotRequired[BoundingBoxResult] # bboxes +}) + + +def draw_polygons(image: Image, prediction: PolygonResult) -> Image: + """ + Draws segmentation masks with polygons on an image. + + Parameters: + - image_path: Path to the image file. + - prediction: Dictionary containing 'polygons' and 'labels' keys. + 'polygons' is a list of lists, each containing vertices of a polygon. + 'labels' is a list of labels corresponding to each polygon. + - fill_mask: Boolean indicating whether to fill the polygons with color. + """ + # Load the image + + draw = ImageDraw.Draw(image) + + # Set up scale factor if needed (use 1 if not scaling) + scale = 1 + + # Iterate over polygons and labels + for polygons, label in zip(prediction['polygons'], prediction['labels']): + for _polygon in polygons: + _polygon = np.array(_polygon).reshape(-1, 2) + if len(_polygon) < 3: + print('Invalid polygon:', _polygon) + continue + + _polygon = (_polygon * scale).reshape(-1).tolist() + + # Draw the polygon + draw.polygon(_polygon, fill='white') + return image + + +class Florence2TaskTokenize(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL",), + "prompt": ("STRING", {"default": "", "multiline": True}), + "task": (TASKS, {"default": TASKS[0]}) + }, + "optional": { + "images": ("IMAGE", {}), + } + } + + CATEGORY = "language" + RETURN_TYPES = (TOKENS_TYPE_NAME, TASKS_TYPE_NAME) + RETURN_NAMES = ("tokens",) + FUNCTION = "execute" + + def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, task: str = "") -> ValidatedNodeResult: + return model.tokenize(prompt, images, task + prompt), + + +class Florence2PostProcess(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL",), + "generated_text": ("STRING", {"forceInput": True}), + "task": (TASKS, {"default": TASKS[0]}) + }, + "optional": { + "images": ("IMAGE", {}), + } + } + + CATEGORY = "language" + RETURN_TYPES = (FLORENCE2_OUTPUT_TYPE_NAME,) + RETURN_NAMES = ("florence2 output",) + FUNCTION = "execute" + + def execute(self, model: TransformersManagedModel, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]: + assert hasattr(model.processor, "post_process_generation") + return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])), + + +class Florence2OutputToPolygon(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "florence2_output": (FLORENCE2_OUTPUT_TYPE_NAME, {}), + }, + "optional": { + "images": ("IMAGE", {}), + } + } + + CATEGORY = "language" + RETURN_TYPES = ("MASK",) + FUNCTION = "execute" + + def execute(self, florence2_output: PostProcessResult, images: RGBImageBatch = None) -> tuple[MaskBatch]: + image = Image.new('RGB', (images.shape[-2], images.shape[-3]), color='black') + for prediction in ('', '', ''): + if prediction in florence2_output: + image = draw_polygons(image, florence2_output[prediction]) + return pil2mask(image), + + +NODE_CLASS_MAPPINGS = {} +for cls in ( + Florence2PostProcess, + Florence2TaskTokenize, + Florence2OutputToPolygon +): + NODE_CLASS_MAPPINGS[cls.__name__] = cls \ No newline at end of file