mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-22 04:10:15 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
f0426a04cc
@ -179,7 +179,10 @@ class Chroma(nn.Module):
|
|||||||
pe = self.pe_embedder(ids)
|
pe = self.pe_embedder(ids)
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.double_blocks):
|
for i, block in enumerate(self.double_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if i not in self.skip_mmdit:
|
if i not in self.skip_mmdit:
|
||||||
double_mod = (
|
double_mod = (
|
||||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||||
@ -222,7 +225,10 @@ class Chroma(nn.Module):
|
|||||||
|
|
||||||
img = torch.cat((txt, img), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
|
|
||||||
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
|
transformer_options["block_type"] = "single"
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if i not in self.skip_dit:
|
if i not in self.skip_dit:
|
||||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
|
|||||||
@ -389,7 +389,10 @@ class HunyuanVideo(nn.Module):
|
|||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.double_blocks):
|
for i, block in enumerate(self.double_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
@ -411,7 +414,10 @@ class HunyuanVideo(nn.Module):
|
|||||||
|
|
||||||
img = torch.cat((img, txt), 1)
|
img = torch.cat((img, txt), 1)
|
||||||
|
|
||||||
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
|
transformer_options["block_type"] = "single"
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
|
||||||
|
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from fractions import Fraction
|
||||||
from typing import Optional, Union, IO
|
from typing import Optional, Union, IO
|
||||||
import io
|
import io
|
||||||
import av
|
import av
|
||||||
@ -72,6 +73,33 @@ class VideoInput(ABC):
|
|||||||
frame_count = components.images.shape[0]
|
frame_count = components.images.shape[0]
|
||||||
return float(frame_count / components.frame_rate)
|
return float(frame_count / components.frame_rate)
|
||||||
|
|
||||||
|
def get_frame_count(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the number of frames in the video.
|
||||||
|
|
||||||
|
Default implementation uses :meth:`get_components`, which may require
|
||||||
|
loading all frames into memory. File-based implementations should
|
||||||
|
override this method and use container/stream metadata instead.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of frames as an integer.
|
||||||
|
"""
|
||||||
|
return int(self.get_components().images.shape[0])
|
||||||
|
|
||||||
|
def get_frame_rate(self) -> Fraction:
|
||||||
|
"""
|
||||||
|
Returns the frame rate of the video.
|
||||||
|
|
||||||
|
Default implementation materializes the video into memory via
|
||||||
|
`get_components()`. Subclasses that can inspect the underlying
|
||||||
|
container (e.g. `VideoFromFile`) should override this with a more
|
||||||
|
efficient implementation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Frame rate as a Fraction.
|
||||||
|
"""
|
||||||
|
return self.get_components().frame_rate
|
||||||
|
|
||||||
def get_container_format(self) -> str:
|
def get_container_format(self) -> str:
|
||||||
"""
|
"""
|
||||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||||
|
|||||||
@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
|
|||||||
|
|
||||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||||
|
|
||||||
|
def get_frame_count(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the number of frames in the video without materializing them as
|
||||||
|
torch tensors.
|
||||||
|
"""
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0)
|
||||||
|
|
||||||
|
with av.open(self.__file, mode="r") as container:
|
||||||
|
video_stream = self._get_first_video_stream(container)
|
||||||
|
# 1. Prefer the frames field if available
|
||||||
|
if video_stream.frames and video_stream.frames > 0:
|
||||||
|
return int(video_stream.frames)
|
||||||
|
|
||||||
|
# 2. Try to estimate from duration and average_rate using only metadata
|
||||||
|
if container.duration is not None and video_stream.average_rate:
|
||||||
|
duration_seconds = float(container.duration / av.time_base)
|
||||||
|
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||||
|
if estimated_frames > 0:
|
||||||
|
return estimated_frames
|
||||||
|
|
||||||
|
if (
|
||||||
|
getattr(video_stream, "duration", None) is not None
|
||||||
|
and getattr(video_stream, "time_base", None) is not None
|
||||||
|
and video_stream.average_rate
|
||||||
|
):
|
||||||
|
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||||
|
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||||
|
if estimated_frames > 0:
|
||||||
|
return estimated_frames
|
||||||
|
|
||||||
|
# 3. Last resort: decode frames and count them (streaming)
|
||||||
|
frame_count = 0
|
||||||
|
container.seek(0)
|
||||||
|
for packet in container.demux(video_stream):
|
||||||
|
for _ in packet.decode():
|
||||||
|
frame_count += 1
|
||||||
|
|
||||||
|
if frame_count == 0:
|
||||||
|
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||||
|
return frame_count
|
||||||
|
|
||||||
|
def get_frame_rate(self) -> Fraction:
|
||||||
|
"""
|
||||||
|
Returns the average frame rate of the video using container metadata
|
||||||
|
without decoding all frames.
|
||||||
|
"""
|
||||||
|
if isinstance(self.__file, io.BytesIO):
|
||||||
|
self.__file.seek(0)
|
||||||
|
|
||||||
|
with av.open(self.__file, mode="r") as container:
|
||||||
|
video_stream = self._get_first_video_stream(container)
|
||||||
|
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
|
||||||
|
if video_stream.average_rate:
|
||||||
|
return Fraction(video_stream.average_rate)
|
||||||
|
|
||||||
|
# Fallback: estimate from frames + duration if available
|
||||||
|
if video_stream.frames and container.duration:
|
||||||
|
duration_seconds = float(container.duration / av.time_base)
|
||||||
|
if duration_seconds > 0:
|
||||||
|
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
|
||||||
|
|
||||||
|
# Last resort: match get_components_internal default
|
||||||
|
return Fraction(1)
|
||||||
|
|
||||||
def get_container_format(self) -> str:
|
def get_container_format(self) -> str:
|
||||||
"""
|
"""
|
||||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||||
@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
|
|||||||
packet.stream = stream_map[packet.stream]
|
packet.stream = stream_map[packet.stream]
|
||||||
output_container.mux(packet)
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
def _get_first_video_stream(self, container: InputContainer):
|
||||||
|
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||||
|
if video_stream is None:
|
||||||
|
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||||
|
return video_stream
|
||||||
|
|
||||||
|
|
||||||
class VideoFromComponents(VideoInput):
|
class VideoFromComponents(VideoInput):
|
||||||
"""
|
"""
|
||||||
Class representing video input from tensors.
|
Class representing video input from tensors.
|
||||||
|
|||||||
@ -113,9 +113,9 @@ class GeminiGenerationConfig(BaseModel):
|
|||||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||||
seed: int | None = Field(None)
|
seed: int | None = Field(None)
|
||||||
stopSequences: list[str] | None = Field(None)
|
stopSequences: list[str] | None = Field(None)
|
||||||
temperature: float | None = Field(1, ge=0.0, le=2.0)
|
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||||
topK: int | None = Field(40, ge=1)
|
topK: int | None = Field(None, ge=1)
|
||||||
topP: float | None = Field(0.95, ge=0.0, le=1.0)
|
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageConfig(BaseModel):
|
class GeminiImageConfig(BaseModel):
|
||||||
|
|||||||
@ -104,14 +104,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
|||||||
List of response parts matching the requested type.
|
List of response parts matching the requested type.
|
||||||
"""
|
"""
|
||||||
if response.candidates is None:
|
if response.candidates is None:
|
||||||
if response.promptFeedback.blockReason:
|
if response.promptFeedback and response.promptFeedback.blockReason:
|
||||||
feedback = response.promptFeedback
|
feedback = response.promptFeedback
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
|
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
|
||||||
)
|
)
|
||||||
raise NotImplementedError(
|
raise ValueError(
|
||||||
"Gemini returned no response candidates. "
|
"Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
|
||||||
"Please report to ComfyUI repository with the example of workflow to reproduce this."
|
"try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
|
||||||
)
|
)
|
||||||
parts = []
|
parts = []
|
||||||
for part in response.candidates[0].content.parts:
|
for part in response.candidates[0].content.parts:
|
||||||
@ -182,11 +182,12 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||||
for i in response.usageMetadata.candidatesTokensDetails:
|
if response.usageMetadata.candidatesTokensDetails:
|
||||||
if i.modality == Modality.IMAGE:
|
for i in response.usageMetadata.candidatesTokensDetails:
|
||||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
if i.modality == Modality.IMAGE:
|
||||||
else:
|
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||||
final_price += output_text_tokens_price * i.tokenCount
|
else:
|
||||||
|
final_price += output_text_tokens_price * i.tokenCount
|
||||||
if response.usageMetadata.thoughtsTokenCount:
|
if response.usageMetadata.thoughtsTokenCount:
|
||||||
final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
|
final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
|
||||||
return final_price / 1_000_000.0
|
return final_price / 1_000_000.0
|
||||||
@ -645,7 +646,7 @@ class GeminiImage2(IO.ComfyNode):
|
|||||||
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||||
default="auto",
|
default="auto",
|
||||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||||
"if no image is provided, generates a 1:1 square.",
|
"if no image is provided, a 16:9 square is usually generated.",
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
|
|||||||
@ -5,8 +5,7 @@ import aiohttp
|
|||||||
import torch
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.input.video_types import VideoInput
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
from comfy_api.latest import IO, ComfyExtension
|
|
||||||
from comfy_api_nodes.apis import topaz_api
|
from comfy_api_nodes.apis import topaz_api
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -282,7 +281,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
video: VideoInput,
|
video: Input.Video,
|
||||||
upscaler_enabled: bool,
|
upscaler_enabled: bool,
|
||||||
upscaler_model: str,
|
upscaler_model: str,
|
||||||
upscaler_resolution: str,
|
upscaler_resolution: str,
|
||||||
@ -297,12 +296,10 @@ class TopazVideoEnhance(IO.ComfyNode):
|
|||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if upscaler_enabled is False and interpolation_enabled is False:
|
if upscaler_enabled is False and interpolation_enabled is False:
|
||||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||||
src_width, src_height = video.get_dimensions()
|
|
||||||
video_components = video.get_components()
|
|
||||||
src_frame_rate = int(video_components.frame_rate)
|
|
||||||
duration_sec = video.get_duration()
|
|
||||||
estimated_frames = int(duration_sec * src_frame_rate)
|
|
||||||
validate_container_format_is_mp4(video)
|
validate_container_format_is_mp4(video)
|
||||||
|
src_width, src_height = video.get_dimensions()
|
||||||
|
src_frame_rate = int(video.get_frame_rate())
|
||||||
|
duration_sec = video.get_duration()
|
||||||
src_video_stream = video.get_stream_source()
|
src_video_stream = video.get_stream_source()
|
||||||
target_width = src_width
|
target_width = src_width
|
||||||
target_height = src_height
|
target_height = src_height
|
||||||
@ -338,7 +335,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
|||||||
container="mp4",
|
container="mp4",
|
||||||
size=get_fs_object_size(src_video_stream),
|
size=get_fs_object_size(src_video_stream),
|
||||||
duration=int(duration_sec),
|
duration=int(duration_sec),
|
||||||
frameCount=estimated_frames,
|
frameCount=video.get_frame_count(),
|
||||||
frameRate=src_frame_rate,
|
frameRate=src_frame_rate,
|
||||||
resolution=topaz_api.Resolution(width=src_width, height=src_height),
|
resolution=topaz_api.Resolution(width=src_width, height=src_height),
|
||||||
),
|
),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user