fix: use glob matching for Gemini image MIME types (#12511)

gemini-3-pro-image-preview nondeterministically returns image/jpeg
instead of image/png. get_image_from_response() hardcoded
get_parts_by_type(response, "image/png"), silently dropping JPEG
responses and falling back to torch.zeros (all-black output).

Add _mime_matches() helper using fnmatch for glob-style MIME matching.
Change get_image_from_response() to request "image/*" so any image
format returned by the API is correctly captured.
This commit is contained in:
Hunter 2026-02-18 00:03:54 -05:00 committed by comfyanonymous
parent 19236edfa4
commit 185c61dc26

View File

@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
import base64 import base64
import os import os
from enum import Enum from enum import Enum
from fnmatch import fnmatch
from io import BytesIO from io import BytesIO
from typing import Literal from typing import Literal
@ -119,6 +120,13 @@ async def create_image_parts(
return image_parts return image_parts
def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
"""Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
if mime is None:
return False
return fnmatch(mime.value, pattern)
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
""" """
Filter response parts by their type. Filter response parts by their type.
@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
for part in candidate.content.parts: for part in candidate.content.parts:
if part_type == "text" and part.text: if part_type == "text" and part.text:
parts.append(part) parts.append(part)
elif part.inlineData and part.inlineData.mimeType == part_type: elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
parts.append(part) parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type: elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
parts.append(part) parts.append(part)
if not parts and blocked_reasons: if not parts and blocked_reasons:
@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = [] image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png") parts = get_parts_by_type(response, "image/*")
for part in parts: for part in parts:
if part.inlineData: if part.inlineData:
image_data = base64.b64decode(part.inlineData.data) image_data = base64.b64decode(part.inlineData.data)