Merge branch 'comfyanonymous:master' into master

Commit: f0426a04cc
Author: patientx
Date: 2025-11-24 22:10:49 +03:00 (committed via GitHub)
8 changed files with 135 additions and 22 deletions

View File

@@ -179,7 +179,10 @@ class Chroma(nn.Module):
         pe = self.pe_embedder(ids)

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),
@@ -222,7 +225,10 @@ class Chroma(nn.Module):
         img = torch.cat((txt, img), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:

View File

@@ -389,7 +389,10 @@ class HunyuanVideo(nn.Module):
         attn_mask = None

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -411,7 +414,10 @@ class HunyuanVideo(nn.Module):
         img = torch.cat((img, txt), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

View File

@@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
         patches = transformer_options.get("patches", {})
         blocks_replace = patches_replace.get("dit", {})

+        transformer_options["total_blocks"] = len(self.transformer_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.transformer_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
+from fractions import Fraction
 from typing import Optional, Union, IO
 import io
 import av
@@ -72,6 +73,33 @@ class VideoInput(ABC):
         frame_count = components.images.shape[0]
         return float(frame_count / components.frame_rate)

+    def get_frame_count(self) -> int:
+        """
+        Returns the number of frames in the video.
+
+        Default implementation uses :meth:`get_components`, which may require
+        loading all frames into memory. File-based implementations should
+        override this method and use container/stream metadata instead.
+
+        Returns:
+            Total number of frames as an integer.
+        """
+        return int(self.get_components().images.shape[0])
+
+    def get_frame_rate(self) -> Fraction:
+        """
+        Returns the frame rate of the video.
+
+        Default implementation materializes the video into memory via
+        `get_components()`. Subclasses that can inspect the underlying
+        container (e.g. `VideoFromFile`) should override this with a more
+        efficient implementation.
+
+        Returns:
+            Frame rate as a Fraction.
+        """
+        return self.get_components().frame_rate
+
     def get_container_format(self) -> str:
         """
         Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
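Taken together, the two new accessors let callers do exact timing math without decoding pixel data. A small usage sketch (the summarize helper and its video argument are illustrative, not part of the API; video is assumed to be any VideoInput):

    from fractions import Fraction

    def summarize(video) -> str:
        frames = video.get_frame_count()   # int; metadata-backed for file inputs
        rate = video.get_frame_rate()      # Fraction, e.g. Fraction(30000, 1001) for 29.97 fps
        seconds = Fraction(frames) / rate  # exact rational arithmetic, no float drift
        return f"{frames} frames @ {float(rate):.3f} fps ~= {float(seconds):.2f} s"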

View File

@@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
         raise ValueError(f"Could not determine duration for file '{self.__file}'")

+    def get_frame_count(self) -> int:
+        """
+        Returns the number of frames in the video without materializing them as
+        torch tensors.
+        """
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)
+        with av.open(self.__file, mode="r") as container:
+            video_stream = self._get_first_video_stream(container)
+
+            # 1. Prefer the frames field if available
+            if video_stream.frames and video_stream.frames > 0:
+                return int(video_stream.frames)
+
+            # 2. Try to estimate from duration and average_rate using only metadata
+            if container.duration is not None and video_stream.average_rate:
+                duration_seconds = float(container.duration / av.time_base)
+                estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
+                if estimated_frames > 0:
+                    return estimated_frames
+
+            if (
+                getattr(video_stream, "duration", None) is not None
+                and getattr(video_stream, "time_base", None) is not None
+                and video_stream.average_rate
+            ):
+                duration_seconds = float(video_stream.duration * video_stream.time_base)
+                estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
+                if estimated_frames > 0:
+                    return estimated_frames
+
+            # 3. Last resort: decode frames and count them (streaming)
+            frame_count = 0
+            container.seek(0)
+            for packet in container.demux(video_stream):
+                for _ in packet.decode():
+                    frame_count += 1
+            if frame_count == 0:
+                raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+            return frame_count
+
+    def get_frame_rate(self) -> Fraction:
+        """
+        Returns the average frame rate of the video using container metadata
+        without decoding all frames.
+        """
+        if isinstance(self.__file, io.BytesIO):
+            self.__file.seek(0)
+        with av.open(self.__file, mode="r") as container:
+            video_stream = self._get_first_video_stream(container)
+
+            # Preferred: use PyAV's average_rate (usually already a Fraction-like)
+            if video_stream.average_rate:
+                return Fraction(video_stream.average_rate)
+
+            # Fallback: estimate from frames + duration if available
+            if video_stream.frames and container.duration:
+                duration_seconds = float(container.duration / av.time_base)
+                if duration_seconds > 0:
+                    return Fraction(video_stream.frames / duration_seconds).limit_denominator()
+
+            # Last resort: match get_components_internal default
+            return Fraction(1)
+
     def get_container_format(self) -> str:
         """
         Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
                     packet.stream = stream_map[packet.stream]
                     output_container.mux(packet)

+    def _get_first_video_stream(self, container: InputContainer):
+        video_stream = next((s for s in container.streams if s.type == "video"), None)
+        if video_stream is None:
+            raise ValueError(f"No video stream found in file '{self.__file}'")
+        return video_stream
+

 class VideoFromComponents(VideoInput):
     """
     Class representing video input from tensors.
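The three-step fallback exists because stream.frames is frequently 0 (some muxers simply don't record it) and container-level duration can be missing too, leaving decode-and-count as the only reliable answer. A standalone PyAV probe, assuming a local file named input.mp4, shows which metadata is actually present:

    import av

    with av.open("input.mp4") as container:
        stream = next(s for s in container.streams if s.type == "video")
        print("frames field:", stream.frames)        # 0 when the muxer didn't record it
        print("average_rate:", stream.average_rate)  # Fraction, or None
        print("container duration (s):",
              container.duration / av.time_base if container.duration else None)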

View File

@@ -113,9 +113,9 @@ class GeminiGenerationConfig(BaseModel):
     maxOutputTokens: int | None = Field(None, ge=16, le=8192)
     seed: int | None = Field(None)
     stopSequences: list[str] | None = Field(None)
-    temperature: float | None = Field(1, ge=0.0, le=2.0)
-    topK: int | None = Field(40, ge=1)
-    topP: float | None = Field(0.95, ge=0.0, le=1.0)
+    temperature: float | None = Field(None, ge=0.0, le=2.0)
+    topK: int | None = Field(None, ge=1)
+    topP: float | None = Field(None, ge=0.0, le=1.0)


 class GeminiImageConfig(BaseModel):
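Switching these Field defaults from concrete values to None means unset sampling parameters can be omitted from the request body, letting the server-side defaults apply instead of always pinning temperature=1, topK=40, topP=0.95. A minimal sketch with a stand-in model (assuming the request is serialized with exclude_none, a common pattern for such API schemas):

    from pydantic import BaseModel, Field

    class GenerationConfig(BaseModel):  # stand-in for GeminiGenerationConfig
        temperature: float | None = Field(None, ge=0.0, le=2.0)
        topK: int | None = Field(None, ge=1)

    print(GenerationConfig().model_dump(exclude_none=True))                 # {}
    print(GenerationConfig(temperature=0.2).model_dump(exclude_none=True))  # {'temperature': 0.2}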

View File

@@ -104,14 +104,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal
         List of response parts matching the requested type.
     """
     if response.candidates is None:
-        if response.promptFeedback.blockReason:
+        if response.promptFeedback and response.promptFeedback.blockReason:
             feedback = response.promptFeedback
             raise ValueError(
                 f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
             )
-        raise NotImplementedError(
-            "Gemini returned no response candidates. "
-            "Please report to ComfyUI repository with the example of workflow to reproduce this."
+        raise ValueError(
+            "Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
+            "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
         )
     parts = []
     for part in response.candidates[0].content.parts:
@@ -182,11 +182,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
     else:
         return None
     final_price = response.usageMetadata.promptTokenCount * input_tokens_price
-    for i in response.usageMetadata.candidatesTokensDetails:
-        if i.modality == Modality.IMAGE:
-            final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
-        else:
-            final_price += output_text_tokens_price * i.tokenCount
+    if response.usageMetadata.candidatesTokensDetails:
+        for i in response.usageMetadata.candidatesTokensDetails:
+            if i.modality == Modality.IMAGE:
+                final_price += output_image_tokens_price * i.tokenCount  # for Nano Banana models
+            else:
+                final_price += output_text_tokens_price * i.tokenCount
     if response.usageMetadata.thoughtsTokenCount:
         final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
     return final_price / 1_000_000.0
@@ -645,7 +646,7 @@ class GeminiImage2(IO.ComfyNode):
                     options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
                     default="auto",
                     tooltip="If set to 'auto', matches your input image's aspect ratio; "
-                    "if no image is provided, generates a 1:1 square.",
+                    "if no image is provided, a 16:9 square is usually generated.",
                 ),
                 IO.Combo.Input(
                     "resolution",

View File

@@ -5,8 +5,7 @@ import aiohttp
 import torch
 from typing_extensions import override

-from comfy_api.input.video_types import VideoInput
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis import topaz_api
 from comfy_api_nodes.util import (
     ApiEndpoint,
@@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode):
     @classmethod
     async def execute(
         cls,
-        video: VideoInput,
+        video: Input.Video,
         upscaler_enabled: bool,
         upscaler_model: str,
         upscaler_resolution: str,
@@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode):
     ) -> IO.NodeOutput:
         if upscaler_enabled is False and interpolation_enabled is False:
             raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
-        src_width, src_height = video.get_dimensions()
-        video_components = video.get_components()
-        src_frame_rate = int(video_components.frame_rate)
-        duration_sec = video.get_duration()
-        estimated_frames = int(duration_sec * src_frame_rate)
         validate_container_format_is_mp4(video)
+        src_width, src_height = video.get_dimensions()
+        src_frame_rate = int(video.get_frame_rate())
+        duration_sec = video.get_duration()
         src_video_stream = video.get_stream_source()
         target_width = src_width
         target_height = src_height
@@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode):
             container="mp4",
             size=get_fs_object_size(src_video_stream),
             duration=int(duration_sec),
-            frameCount=estimated_frames,
+            frameCount=video.get_frame_count(),
             frameRate=src_frame_rate,
             resolution=topaz_api.Resolution(width=src_width, height=src_height),
         ),