This commit is contained in:
Nicolas Martel 2026-02-03 18:08:12 +01:00 committed by GitHub
commit 522f899847
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 68 additions and 6 deletions

View File

@ -5749,8 +5749,8 @@ class EasyInputMessage(BaseModel):
class GeminiContent(BaseModel): class GeminiContent(BaseModel):
parts: List[GeminiPart] parts: List[GeminiPart] = Field(default_factory=list)
role: Role1 = Field(..., examples=['user']) role: Optional[Role1] = Field(None, examples=['user'])
class GeminiGenerateContentRequest(BaseModel): class GeminiGenerateContentRequest(BaseModel):

View File

@ -75,7 +75,7 @@ class GeminiTextPart(BaseModel):
class GeminiContent(BaseModel): class GeminiContent(BaseModel):
parts: list[GeminiPart] = Field([]) parts: list[GeminiPart] = Field([])
role: GeminiRole = Field(..., examples=["user"]) role: GeminiRole | None = Field(None, examples=["user"])
class GeminiSystemInstructionContent(BaseModel): class GeminiSystemInstructionContent(BaseModel):

View File

@ -119,6 +119,45 @@ async def create_image_parts(
return image_parts return image_parts
def _summarize_gemini_response_issues(response: GeminiGenerateContentResponse) -> str:
details: list[str] = []
if response.promptFeedback and response.promptFeedback.blockReason:
msg = f"promptFeedback.blockReason={response.promptFeedback.blockReason}"
if response.promptFeedback.blockReasonMessage:
msg = f"{msg} ({response.promptFeedback.blockReasonMessage})"
details.append(msg)
finish_reasons = sorted(
{
candidate.finishReason
for candidate in (response.candidates or [])
if candidate.finishReason
}
)
if finish_reasons:
details.append(f"finishReason(s)={', '.join(finish_reasons)}")
safety_ratings: set[str] = set()
for candidate in response.candidates or []:
for rating in candidate.safetyRatings or []:
if rating.category and rating.probability:
safety_ratings.add(f"{rating.category}:{rating.probability}")
elif rating.category:
safety_ratings.add(str(rating.category))
if safety_ratings:
details.append(f"safetyRatings={', '.join(sorted(safety_ratings))}")
candidates = response.candidates or []
if candidates:
missing_content = sum(
1 for candidate in candidates if candidate.content is None or candidate.content.parts is None
)
if missing_content:
details.append(f"candidates_missing_content={missing_content}/{len(candidates)}")
return "; ".join(details)
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
""" """
Filter response parts by their type. Filter response parts by their type.
@ -156,8 +195,21 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
elif part.fileData and part.fileData.mimeType == part_type: elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part) parts.append(part)
if not parts and blocked_reasons: if not parts:
raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}") if blocked_reasons:
raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}")
if part_type == "text":
return []
details = _summarize_gemini_response_issues(response)
if details:
raise ValueError(
f"Gemini API returned no {part_type} parts. Details: {details}. "
"If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
)
raise ValueError(
f"Gemini API returned no {part_type} parts. "
"If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
)
return parts return parts
@ -187,7 +239,17 @@ async def get_image_from_response(response: GeminiGenerateContentResponse) -> In
returned_image = await download_url_to_image_tensor(part.fileData.fileUri) returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image) image_tensors.append(returned_image)
if len(image_tensors) == 0: if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4)) details = _summarize_gemini_response_issues(response)
if details:
raise ValueError(
"Gemini API returned no image parts. "
f"Details: {details}. "
"If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
)
raise ValueError(
"Gemini API returned no image parts. "
"If you are using the `IMAGE` modality, try `IMAGE+TEXT` to see why image generation failed."
)
return torch.cat(image_tensors, dim=0) return torch.cat(image_tensors, dim=0)