This commit is contained in:
Alexander Piskun 2025-12-14 11:02:38 +01:00 committed by GitHub
commit 8890fc5999
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -34,6 +34,7 @@ from comfy_api_nodes.util import (
ApiEndpoint,
audio_to_base64_string,
bytesio_to_image_tensor,
download_url_to_image_tensor,
get_number_of_images,
sync_op,
tensor_to_base64_string,
@ -141,9 +142,11 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and hasattr(part, "text") and part.text:
if part_type == "text" and part.text:
parts.append(part)
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
return parts
@ -163,12 +166,15 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
else:
returned_image = await download_url_to_image_tensor(part.fileData.fileUri)
image_tensors.append(returned_image)
if len(image_tensors) == 0:
return torch.zeros((1, 1024, 1024, 4))
@ -596,7 +602,7 @@ class GeminiImage(IO.ComfyNode):
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini-image/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -610,7 +616,7 @@ class GeminiImage(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiImage2(IO.ComfyNode):
@ -729,7 +735,7 @@ class GeminiImage2(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
ApiEndpoint(path=f"/proxy/vertexai/gemini-image/{model}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
@ -743,7 +749,7 @@ class GeminiImage2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
class GeminiExtension(ComfyExtension):