Florence2

This commit is contained in:
doctorpangloss 2025-02-04 15:17:14 -08:00
parent ce3583ad42
commit 80db9a8e25
4 changed files with 189 additions and 5 deletions

View File

@ -5,7 +5,6 @@ import inspect
import logging import logging
import operator import operator
import pathlib import pathlib
import warnings
from functools import reduce from functools import reduce
from typing import Optional, Any, Callable from typing import Optional, Any, Callable
@ -24,8 +23,13 @@ from ..component_model.tensor_types import RGBImageBatch
from ..model_downloader import get_or_download_huggingface_repo from ..model_downloader import get_or_download_huggingface_repo
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
from ..model_management_types import ModelManageable from ..model_management_types import ModelManageable
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block, tensor2pil from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block
# tweaks to support florence 2
_OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']
# should be added if the expectation is that this model emits special tokens
_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}
class TransformersManagedModel(ModelManageable, LanguageModel): class TransformersManagedModel(ModelManageable, LanguageModel):
def __init__( def __init__(
@ -46,6 +50,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
self.offload_device = unet_offload_device() self.offload_device = unet_offload_device()
self._config_dict = config_dict self._config_dict = config_dict
self._on_set_processor(self._processor) self._on_set_processor(self._processor)
self._model_type = ""
if model.device != self.offload_device: if model.device != self.offload_device:
model.to(device=self.offload_device) model.to(device=self.offload_device)
@ -94,7 +99,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props) model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props) model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props) model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
else: else:
model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props) model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
@ -139,6 +144,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
processor=processor processor=processor
) )
model_managed._model_type = model_type
return model_managed return model_managed
def generate(self, tokens: TOKENS_TYPE = None, def generate(self, tokens: TOKENS_TYPE = None,
@ -229,7 +236,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
prev_src_lang = None prev_src_lang = None
# todo: is this redundant consider I'm decoding in the on_finalized_text block? # todo: is this redundant consider I'm decoding in the on_finalized_text block?
try: try:
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=self._model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS, clean_up_tokenization_spaces=False)
finally: finally:
if prev_src_lang is not None: if prev_src_lang is not None:
tokenizer.src_lang = prev_src_lang tokenizer.src_lang = prev_src_lang

View File

@ -451,6 +451,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
'THUDM/chatglm3-6b', 'THUDM/chatglm3-6b',
'roborovski/superprompt-v1', 'roborovski/superprompt-v1',
'Qwen/Qwen2-VL-7B-Instruct', 'Qwen/Qwen2-VL-7B-Instruct',
'microsoft/Florence-2-large-ft',
} }
KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([

View File

@ -70,7 +70,6 @@ else:
logging.debug("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") logging.debug("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
# deprecate PROGRESS_BAR_ENABLED # deprecate PROGRESS_BAR_ENABLED
def _get_progress_bar_enabled(): def _get_progress_bar_enabled():
warnings.warn( warnings.warn(
@ -1230,6 +1229,12 @@ def tensor2pil(t_image: torch.Tensor) -> Image:
return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
def pil2mask(image):
image_np = np.array(image.convert("L")).astype(np.float32) / 255.0
mask = torch.from_numpy(image_np)
return 1.0 - mask
def reshape_mask(input_mask, output_shape): def reshape_mask(input_mask, output_shape):
dims = len(output_shape) - 2 dims = len(output_shape) - 2

View File

@ -0,0 +1,171 @@
from typing import List, Union, Optional
import numpy as np
import torch
from PIL import Image, ImageDraw
from typing_extensions import TypedDict, NotRequired
from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
from comfy.language.language_types import TOKENS_TYPE_NAME, LanguageModel
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import pil2mask
TASKS = ['<CAPTION>', '<DETAILED_CAPTION>', '<MORE_DETAILED_CAPTION>', '<OD>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<CAPTION_TO_PHRASE_GROUNDING>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>', '<OCR>', '<OCR_WITH_REGION>']
TASKS_TYPE_NAME = "FLORENCE2_TASK"
FLORENCE2_OUTPUT_TYPE_NAME = "FLORENCE2_OUTPUT"
class BoundingBoxResult(TypedDict):
bboxes: List[List[float]] # List of [x1, y1, x2, y2] coordinates
labels: List[str]
scores: Optional[List[float]] # Only present if score mode is used
class QuadBoxResult(TypedDict):
quad_boxes: List[List[float]] # List of [x1, y1, x2, y2, x3, y3, x4, y4] coordinates
labels: List[str]
class PolygonResult(TypedDict):
polygons: List[List[float]] # List of [x1, y1, x2, y2, ...] coordinates
labels: List[str]
class BBoxesAndPolygonsResult(TypedDict):
bboxes: List[List[float]]
bboxes_labels: List[str]
polygons: List[List[float]]
polygons_labels: List[str]
PostProcessResult = TypedDict('PostProcessResult', {
'<OCR>': NotRequired[Union[str, QuadBoxResult]], # pure_text or ocr
'<OCR_WITH_REGION>': NotRequired[QuadBoxResult], # ocr
'<CAPTION>': NotRequired[str], # pure_text
'<DETAILED_CAPTION>': NotRequired[str], # pure_text
'<MORE_DETAILED_CAPTION>': NotRequired[str], # pure_text
'<OD>': NotRequired[BoundingBoxResult], # description_with_bboxes
'<DENSE_REGION_CAPTION>': NotRequired[BoundingBoxResult], # description_with_bboxes
'<CAPTION_TO_PHRASE_GROUNDING>': NotRequired[BoundingBoxResult], # phrase_grounding
'<REFERRING_EXPRESSION_SEGMENTATION>': NotRequired[PolygonResult], # polygons
'<REGION_TO_SEGMENTATION>': NotRequired[PolygonResult], # polygons
'<OPEN_VOCABULARY_DETECTION>': NotRequired[BBoxesAndPolygonsResult], # description_with_bboxes_or_polygons
'<REGION_TO_CATEGORY>': NotRequired[str], # pure_text
'<REGION_TO_DESCRIPTION>': NotRequired[str], # pure_text
'<REGION_TO_OCR>': NotRequired[str], # pure_text
'<REGION_PROPOSAL>': NotRequired[BoundingBoxResult] # bboxes
})
def draw_polygons(image: Image, prediction: PolygonResult) -> Image:
"""
Draws segmentation masks with polygons on an image.
Parameters:
- image_path: Path to the image file.
- prediction: Dictionary containing 'polygons' and 'labels' keys.
'polygons' is a list of lists, each containing vertices of a polygon.
'labels' is a list of labels corresponding to each polygon.
- fill_mask: Boolean indicating whether to fill the polygons with color.
"""
# Load the image
draw = ImageDraw.Draw(image)
# Set up scale factor if needed (use 1 if not scaling)
scale = 1
# Iterate over polygons and labels
for polygons, label in zip(prediction['polygons'], prediction['labels']):
for _polygon in polygons:
_polygon = np.array(_polygon).reshape(-1, 2)
if len(_polygon) < 3:
print('Invalid polygon:', _polygon)
continue
_polygon = (_polygon * scale).reshape(-1).tolist()
# Draw the polygon
draw.polygon(_polygon, fill='white')
return image
class Florence2TaskTokenize(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"prompt": ("STRING", {"default": "", "multiline": True}),
"task": (TASKS, {"default": TASKS[0]})
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = (TOKENS_TYPE_NAME, TASKS_TYPE_NAME)
RETURN_NAMES = ("tokens",)
FUNCTION = "execute"
def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, task: str = "") -> ValidatedNodeResult:
return model.tokenize(prompt, images, task + prompt),
class Florence2PostProcess(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL",),
"generated_text": ("STRING", {"forceInput": True}),
"task": (TASKS, {"default": TASKS[0]})
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = (FLORENCE2_OUTPUT_TYPE_NAME,)
RETURN_NAMES = ("florence2 output",)
FUNCTION = "execute"
def execute(self, model: TransformersManagedModel, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]:
assert hasattr(model.processor, "post_process_generation")
return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])),
class Florence2OutputToPolygon(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"florence2_output": (FLORENCE2_OUTPUT_TYPE_NAME, {}),
},
"optional": {
"images": ("IMAGE", {}),
}
}
CATEGORY = "language"
RETURN_TYPES = ("MASK",)
FUNCTION = "execute"
def execute(self, florence2_output: PostProcessResult, images: RGBImageBatch = None) -> tuple[MaskBatch]:
image = Image.new('RGB', (images.shape[-2], images.shape[-3]), color='black')
for prediction in ('<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>'):
if prediction in florence2_output:
image = draw_polygons(image, florence2_output[prediction])
return pil2mask(image),
NODE_CLASS_MAPPINGS = {}
for cls in (
Florence2PostProcess,
Florence2TaskTokenize,
Florence2OutputToPolygon
):
NODE_CLASS_MAPPINGS[cls.__name__] = cls