mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
fix: use glob matching for Gemini image MIME types (#12511)
gemini-3-pro-image-preview nondeterministically returns image/jpeg instead of image/png. get_image_from_response() hardcoded get_parts_by_type(response, "image/png"), silently dropping JPEG responses and falling back to torch.zeros (all-black output). Add _mime_matches() helper using fnmatch for glob-style MIME matching. Change get_image_from_response() to request "image/*" so any image format returned by the API is correctly captured.
This commit is contained in:
parent
8ad38d2073
commit
83dd65f23a
@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
|||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from fnmatch import fnmatch
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@ -119,6 +120,13 @@ async def create_image_parts(
|
|||||||
return image_parts
|
return image_parts
|
||||||
|
|
||||||
|
|
||||||
|
def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
|
||||||
|
"""Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
|
||||||
|
if mime is None:
|
||||||
|
return False
|
||||||
|
return fnmatch(mime.value, pattern)
|
||||||
|
|
||||||
|
|
||||||
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
||||||
"""
|
"""
|
||||||
Filter response parts by their type.
|
Filter response parts by their type.
|
||||||
@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
|||||||
for part in candidate.content.parts:
|
for part in candidate.content.parts:
|
||||||
if part_type == "text" and part.text:
|
if part_type == "text" and part.text:
|
||||||
parts.append(part)
|
parts.append(part)
|
||||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
|
||||||
parts.append(part)
|
parts.append(part)
|
||||||
elif part.fileData and part.fileData.mimeType == part_type:
|
elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
|
||||||
parts.append(part)
|
parts.append(part)
|
||||||
|
|
||||||
if not parts and blocked_reasons:
|
if not parts and blocked_reasons:
|
||||||
@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
|||||||
|
|
||||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||||
image_tensors: list[Input.Image] = []
|
image_tensors: list[Input.Image] = []
|
||||||
parts = get_parts_by_type(response, "image/png")
|
parts = get_parts_by_type(response, "image/*")
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if part.inlineData:
|
if part.inlineData:
|
||||||
image_data = base64.b64decode(part.inlineData.data)
|
image_data = base64.b64decode(part.inlineData.data)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user