From 80db9a8e25db1341e99e343baee0c741a869038c Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 4 Feb 2025 15:17:14 -0800 Subject: [PATCH] Florence2 --- .../language/transformers_model_management.py | 15 +- comfy/model_downloader.py | 1 + comfy/utils.py | 7 +- comfy_extras/nodes/nodes_florence2.py | 171 ++++++++++++++++++ 4 files changed, 189 insertions(+), 5 deletions(-) create mode 100644 comfy_extras/nodes/nodes_florence2.py diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index 688cb6f81..e3980c254 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -5,7 +5,6 @@ import inspect import logging import operator import pathlib -import warnings from functools import reduce from typing import Optional, Any, Callable @@ -24,8 +23,13 @@ from ..component_model.tensor_types import RGBImageBatch from ..model_downloader import get_or_download_huggingface_repo from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu from ..model_management_types import ModelManageable -from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block, tensor2pil +from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block +# tweaks to support florence 2 +_OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2'] + +# should be added if the expectation is that this model emits special tokens +_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'} class TransformersManagedModel(ModelManageable, LanguageModel): def __init__( @@ -46,6 +50,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel): self.offload_device = unet_offload_device() self._config_dict = config_dict self._on_set_processor(self._processor) + self._model_type = "" if model.device != self.offload_device: model.to(device=self.offload_device) @@ -94,7 +99,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel): model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props) elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props) - elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props) else: model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props) @@ -139,6 +144,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel): processor=processor ) + model_managed._model_type = model_type + return model_managed def generate(self, tokens: TOKENS_TYPE = None, @@ -229,7 +236,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel): prev_src_lang = None # todo: is this redundant consider I'm decoding in the on_finalized_text block? try: - outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=self._model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS, clean_up_tokenization_spaces=False) finally: if prev_src_lang is not None: tokenizer.src_lang = prev_src_lang diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 3b8810eb8..ca3b28d13 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -451,6 +451,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = { 'THUDM/chatglm3-6b', 'roborovski/superprompt-v1', 'Qwen/Qwen2-VL-7B-Instruct', + 'microsoft/Florence-2-large-ft', } KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/utils.py b/comfy/utils.py index 0fcb88a71..d3293c5cf 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -70,7 +70,6 @@ else: logging.debug("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") - # deprecate PROGRESS_BAR_ENABLED def _get_progress_bar_enabled(): warnings.warn( @@ -1230,6 +1229,12 @@ def tensor2pil(t_image: torch.Tensor) -> Image: return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) +def pil2mask(image): + image_np = np.array(image.convert("L")).astype(np.float32) / 255.0 + mask = torch.from_numpy(image_np) + return 1.0 - mask + + def reshape_mask(input_mask, output_shape): dims = len(output_shape) - 2 diff --git a/comfy_extras/nodes/nodes_florence2.py b/comfy_extras/nodes/nodes_florence2.py new file mode 100644 index 000000000..93d99a932 --- /dev/null +++ b/comfy_extras/nodes/nodes_florence2.py @@ -0,0 +1,171 @@ +from typing import List, Union, Optional + +import numpy as np +import torch +from PIL import Image, ImageDraw +from typing_extensions import TypedDict, NotRequired + +from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch +from comfy.language.language_types import TOKENS_TYPE_NAME, LanguageModel +from comfy.language.transformers_model_management import TransformersManagedModel +from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult +from comfy.utils import pil2mask + +TASKS = ['