mirror of https://github.com/comfyanonymous/ComfyUI.git
Florence2

parent ce3583ad42
commit 80db9a8e25
@@ -5,7 +5,6 @@ import inspect
import logging
import operator
import pathlib
import warnings
from functools import reduce
from typing import Optional, Any, Callable

@@ -24,8 +23,13 @@ from ..component_model.tensor_types import RGBImageBatch
from ..model_downloader import get_or_download_huggingface_repo
from ..model_management import unet_offload_device, get_torch_device, unet_dtype, load_models_gpu
from ..model_management_types import ModelManageable
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block, tensor2pil
from ..utils import comfy_tqdm, ProgressBar, comfy_progress, seed_for_block

# tweaks to support florence 2
_OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']

# should be added if the expectation is that this model emits special tokens
_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}


class TransformersManagedModel(ModelManageable, LanguageModel):
    def __init__(
@@ -46,6 +50,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
        self.offload_device = unet_offload_device()
        self._config_dict = config_dict
        self._on_set_processor(self._processor)
        self._model_type = ""
        if model.device != self.offload_device:
            model.to(device=self.offload_device)

@@ -94,7 +99,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
            model = AutoModelForVision2Seq.from_pretrained(**from_pretrained_kwargs, **props)
        elif model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
            model = AutoModelForSeq2SeqLM.from_pretrained(**from_pretrained_kwargs, **props)
        elif model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        elif model_type in _OVERRIDDEN_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            model = AutoModelForCausalLM.from_pretrained(**from_pretrained_kwargs, **props)
        else:
            model = AutoModel.from_pretrained(**from_pretrained_kwargs, **props)
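Editor's sketch (not part of this commit): the overridden list above matters because, in the transformers version this change targets, the stock causal-LM mapping has no 'florence2' entry, so without the extra key the dispatch above would fall through to the generic AutoModel branch.

# Illustrative only: how the _OVERRIDDEN_* list changes the dispatch above.
# MODEL_FOR_CAUSAL_LM_MAPPING_NAMES is the real transformers mapping; whether it
# already contains 'florence2' depends on the installed transformers version.
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

overridden = list(MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.keys()) + ['florence2']

for model_type in ('llama', 'florence2'):
    if model_type in overridden:
        print(model_type, '-> AutoModelForCausalLM')  # florence2 now takes this branch
    else:
        print(model_type, '-> AutoModel')             # previously florence2 fell through to here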
@@ -139,6 +144,8 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
            processor=processor
        )

        model_managed._model_type = model_type

        return model_managed

    def generate(self, tokens: TOKENS_TYPE = None,
@@ -229,7 +236,7 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
        prev_src_lang = None
        # todo: is this redundant considering I'm decoding in the on_finalized_text block?
        try:
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=self._model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS, clean_up_tokenization_spaces=False)
        finally:
            if prev_src_lang is not None:
                tokenizer.src_lang = prev_src_lang
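Editor's sketch of the decode policy introduced above (only names from this diff are assumed): Florence-2's post-processor parses the task and location tokens (e.g. <OD>, <loc_123>) out of the raw text, so special tokens are kept only for model types listed in _DO_NOT_SKIP_SPECIAL_TOKENS.

# Minimal sketch, not the repository's code.
_DO_NOT_SKIP_SPECIAL_TOKENS = {'florence2'}

def skip_special_tokens_for(model_type: str) -> bool:
    # mirrors: skip_special_tokens=self._model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS
    return model_type not in _DO_NOT_SKIP_SPECIAL_TOKENS

print(skip_special_tokens_for('florence2'))  # False: keep <OD>, <loc_*> tokens for post-processing
print(skip_special_tokens_for('llama'))      # True: ordinary text models drop special tokens as before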
@@ -451,6 +451,7 @@ KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
    'THUDM/chatglm3-6b',
    'roborovski/superprompt-v1',
    'Qwen/Qwen2-VL-7B-Instruct',
    'microsoft/Florence-2-large-ft',
}

KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
@@ -70,7 +70,6 @@ else:
    logging.debug("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")


# deprecate PROGRESS_BAR_ENABLED
def _get_progress_bar_enabled():
    warnings.warn(
@@ -1230,6 +1229,12 @@ def tensor2pil(t_image: torch.Tensor) -> Image:
    return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))


def pil2mask(image):
    image_np = np.array(image.convert("L")).astype(np.float32) / 255.0
    mask = torch.from_numpy(image_np)
    return 1.0 - mask


def reshape_mask(input_mask, output_shape):
    dims = len(output_shape) - 2
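Editor's usage sketch for the new pil2mask helper (illustrative only): the grayscale channel is scaled to [0, 1] and inverted, so white input pixels map to 0.0 and black input pixels map to 1.0 in the returned tensor.

# Self-contained check of pil2mask's behavior as defined above.
import numpy as np
import torch
from PIL import Image

def pil2mask(image):
    image_np = np.array(image.convert("L")).astype(np.float32) / 255.0
    mask = torch.from_numpy(image_np)
    return 1.0 - mask

print(pil2mask(Image.new("RGB", (4, 4), "white")).max().item())  # 0.0: white pixels map to 0.0
print(pil2mask(Image.new("RGB", (4, 4), "black")).max().item())  # 1.0: black pixels map to 1.0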
comfy_extras/nodes/nodes_florence2.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from typing import List, Union, Optional

import numpy as np
import torch
from PIL import Image, ImageDraw
from typing_extensions import TypedDict, NotRequired

from comfy.component_model.tensor_types import RGBImageBatch, MaskBatch
from comfy.language.language_types import TOKENS_TYPE_NAME, LanguageModel
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import pil2mask

TASKS = ['<CAPTION>', '<DETAILED_CAPTION>', '<MORE_DETAILED_CAPTION>', '<OD>', '<DENSE_REGION_CAPTION>', '<REGION_PROPOSAL>', '<CAPTION_TO_PHRASE_GROUNDING>', '<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>', '<REGION_TO_CATEGORY>', '<REGION_TO_DESCRIPTION>', '<OCR>', '<OCR_WITH_REGION>']
TASKS_TYPE_NAME = "FLORENCE2_TASK"
FLORENCE2_OUTPUT_TYPE_NAME = "FLORENCE2_OUTPUT"


class BoundingBoxResult(TypedDict):
    bboxes: List[List[float]]  # List of [x1, y1, x2, y2] coordinates
    labels: List[str]
    scores: Optional[List[float]]  # Only present if score mode is used


class QuadBoxResult(TypedDict):
    quad_boxes: List[List[float]]  # List of [x1, y1, x2, y2, x3, y3, x4, y4] coordinates
    labels: List[str]


class PolygonResult(TypedDict):
    polygons: List[List[float]]  # List of [x1, y1, x2, y2, ...] coordinates
    labels: List[str]


class BBoxesAndPolygonsResult(TypedDict):
    bboxes: List[List[float]]
    bboxes_labels: List[str]
    polygons: List[List[float]]
    polygons_labels: List[str]


PostProcessResult = TypedDict('PostProcessResult', {
    '<OCR>': NotRequired[Union[str, QuadBoxResult]],  # pure_text or ocr
    '<OCR_WITH_REGION>': NotRequired[QuadBoxResult],  # ocr
    '<CAPTION>': NotRequired[str],  # pure_text
    '<DETAILED_CAPTION>': NotRequired[str],  # pure_text
    '<MORE_DETAILED_CAPTION>': NotRequired[str],  # pure_text
    '<OD>': NotRequired[BoundingBoxResult],  # description_with_bboxes
    '<DENSE_REGION_CAPTION>': NotRequired[BoundingBoxResult],  # description_with_bboxes
    '<CAPTION_TO_PHRASE_GROUNDING>': NotRequired[BoundingBoxResult],  # phrase_grounding
    '<REFERRING_EXPRESSION_SEGMENTATION>': NotRequired[PolygonResult],  # polygons
    '<REGION_TO_SEGMENTATION>': NotRequired[PolygonResult],  # polygons
    '<OPEN_VOCABULARY_DETECTION>': NotRequired[BBoxesAndPolygonsResult],  # description_with_bboxes_or_polygons
    '<REGION_TO_CATEGORY>': NotRequired[str],  # pure_text
    '<REGION_TO_DESCRIPTION>': NotRequired[str],  # pure_text
    '<REGION_TO_OCR>': NotRequired[str],  # pure_text
    '<REGION_PROPOSAL>': NotRequired[BoundingBoxResult]  # bboxes
})
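To make the TypedDict above concrete, an editor-invented example of one post-processed object-detection result; the structure follows BoundingBoxResult, the numbers and labels are fabricated for illustration only.

# Made-up example value for the '<OD>' task entry of PostProcessResult.
example_output = {
    '<OD>': {
        'bboxes': [[34.2, 16.0, 412.7, 380.9], [10.0, 20.0, 96.5, 110.3]],  # [x1, y1, x2, y2] per box
        'labels': ['cat', 'ball'],
        'scores': None,  # populated only when a score-returning mode is used
    }
}
print(example_output['<OD>']['labels'])  # ['cat', 'ball']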
def draw_polygons(image: Image, prediction: PolygonResult) -> Image:
    """
    Draws segmentation masks with polygons on an image.

    Parameters:
    - image: PIL image to draw onto (modified in place and returned).
    - prediction: Dictionary containing 'polygons' and 'labels' keys.
      'polygons' is a list of lists, each containing vertices of a polygon.
      'labels' is a list of labels corresponding to each polygon.
    """
    draw = ImageDraw.Draw(image)

    # Set up scale factor if needed (use 1 if not scaling)
    scale = 1

    # Iterate over polygons and labels
    for polygons, label in zip(prediction['polygons'], prediction['labels']):
        for _polygon in polygons:
            _polygon = np.array(_polygon).reshape(-1, 2)
            if len(_polygon) < 3:
                print('Invalid polygon:', _polygon)
                continue

            _polygon = (_polygon * scale).reshape(-1).tolist()

            # Draw the polygon filled with white
            draw.polygon(_polygon, fill='white')
    return image
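Editor's sketch tying draw_polygons to pil2mask (assumes the definitions above and the imports at the top of this file): rasterize one square polygon the way Florence2OutputToPolygon does, then inspect the resulting mask values.

# Illustration only: one hand-written polygon prediction with fabricated coordinates.
canvas = Image.new('RGB', (64, 64), color='black')
prediction = {'polygons': [[[8.0, 8.0, 56.0, 8.0, 56.0, 56.0, 8.0, 56.0]]], 'labels': ['box']}
canvas = draw_polygons(canvas, prediction)
mask = pil2mask(canvas)
print(mask.shape, mask.min().item(), mask.max().item())  # torch.Size([64, 64]) 0.0 1.0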
class Florence2TaskTokenize(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "prompt": ("STRING", {"default": "", "multiline": True}),
                "task": (TASKS, {"default": TASKS[0]})
            },
            "optional": {
                "images": ("IMAGE", {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = (TOKENS_TYPE_NAME, TASKS_TYPE_NAME)
    RETURN_NAMES = ("tokens",)
    FUNCTION = "execute"

    def execute(self, model: LanguageModel, prompt: str, images: List[torch.Tensor] | torch.Tensor = None, task: str = "") -> ValidatedNodeResult:
        return model.tokenize(prompt, images, task + prompt),
class Florence2PostProcess(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL",),
                "generated_text": ("STRING", {"forceInput": True}),
                "task": (TASKS, {"default": TASKS[0]})
            },
            "optional": {
                "images": ("IMAGE", {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = (FLORENCE2_OUTPUT_TYPE_NAME,)
    RETURN_NAMES = ("florence2 output",)
    FUNCTION = "execute"

    def execute(self, model: TransformersManagedModel, generated_text: str = "", task: str = "", images: RGBImageBatch = None) -> tuple[PostProcessResult]:
        assert hasattr(model.processor, "post_process_generation")
        return model.processor.post_process_generation(generated_text, task=task, image_size=(images.shape[-2], images.shape[-3])),
class Florence2OutputToPolygon(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "florence2_output": (FLORENCE2_OUTPUT_TYPE_NAME, {}),
            },
            "optional": {
                "images": ("IMAGE", {}),
            }
        }

    CATEGORY = "language"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "execute"

    def execute(self, florence2_output: PostProcessResult, images: RGBImageBatch = None) -> tuple[MaskBatch]:
        image = Image.new('RGB', (images.shape[-2], images.shape[-3]), color='black')
        for prediction in ('<REFERRING_EXPRESSION_SEGMENTATION>', '<REGION_TO_SEGMENTATION>', '<OPEN_VOCABULARY_DETECTION>'):
            if prediction in florence2_output:
                image = draw_polygons(image, florence2_output[prediction])
        return pil2mask(image),
NODE_CLASS_MAPPINGS = {}
for cls in (
        Florence2PostProcess,
        Florence2TaskTokenize,
        Florence2OutputToPolygon
):
    NODE_CLASS_MAPPINGS[cls.__name__] = cls
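Read together (an editor's reading of the code above, not text from the commit), the pieces form one pipeline: Florence2TaskTokenize builds tokens from the prompt with the selected task token prepended; generation then decodes with special tokens preserved for florence2 (the _DO_NOT_SKIP_SPECIAL_TOKENS change earlier in this commit), so the raw text still carries the task and location tokens; Florence2PostProcess passes that text and the image size to the processor's post_process_generation; and Florence2OutputToPolygon rasterizes any polygons in the parsed result onto a black canvas and converts it to a MASK with the new pil2mask helper.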