mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-04-30 20:32:45 +08:00)

commit f44cf39b7f
Merge branch 'master' into automation/comfyui-frontend-bump
@@ -807,6 +807,7 @@ class VAE:
             self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
             self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
         elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
+            sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."})
             self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
             self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
@@ -158,10 +158,17 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
     ("Custom", None, None),
 ]

-# Seedance 2.0 reference video pixel count limits per model.
+# Seedance 2.0 reference video pixel count limits per model and output resolution.
 SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
-    "dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408},
-    "dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408},
+    "dreamina-seedance-2-0-260128": {
+        "480p": {"min": 409_600, "max": 927_408},
+        "720p": {"min": 409_600, "max": 927_408},
+        "1080p": {"min": 409_600, "max": 2_073_600},
+    },
+    "dreamina-seedance-2-0-fast-260128": {
+        "480p": {"min": 409_600, "max": 927_408},
+        "720p": {"min": 409_600, "max": 927_408},
+    },
 }

 # The time in this dictionary are given for 10 seconds duration.
@@ -35,6 +35,7 @@ from comfy_api_nodes.util import (
     get_number_of_images,
     image_tensor_pair_to_batch,
     poll_op,
+    resize_video_to_pixel_budget,
     sync_op,
     upload_audio_to_comfyapi,
     upload_image_to_comfyapi,
@@ -69,9 +70,12 @@ DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-2504
 logger = logging.getLogger(__name__)


-def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None:
-    """Validate reference video pixel count against Seedance 2.0 model limits."""
-    limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
+def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None:
+    """Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution."""
+    model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id)
+    if not model_limits:
+        return
+    limits = model_limits.get(resolution)
     if not limits:
         return
     try:
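The two-level table fails open: an unknown model or a missing resolution entry skips validation entirely. A minimal standalone sketch of the same lookup chain, with the table copied from this commit (the lookup_limits helper is illustrative, not code from the PR):

# Sketch of the model -> resolution -> {min, max} lookup introduced above.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
    "dreamina-seedance-2-0-260128": {
        "480p": {"min": 409_600, "max": 927_408},
        "720p": {"min": 409_600, "max": 927_408},
        "1080p": {"min": 409_600, "max": 2_073_600},  # 1920 * 1080
    },
    "dreamina-seedance-2-0-fast-260128": {
        "480p": {"min": 409_600, "max": 927_408},
        "720p": {"min": 409_600, "max": 927_408},
    },
}

def lookup_limits(model_id: str, resolution: str) -> dict | None:
    """Return {"min": ..., "max": ...} or None when unconstrained."""
    return SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(resolution)

assert lookup_limits("dreamina-seedance-2-0-260128", "1080p")["max"] == 1920 * 1080
assert lookup_limits("dreamina-seedance-2-0-fast-260128", "1080p") is None  # fast model caps at 720p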
@@ -1373,6 +1377,14 @@ def _seedance2_reference_inputs(resolutions: list[str]):
                 min=0,
             ),
         ),
+        IO.Boolean.Input(
+            "auto_downscale",
+            default=False,
+            advanced=True,
+            optional=True,
+            tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
+            "for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
+        ),
     ]


@@ -1480,10 +1492,23 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
         model_id = SEEDANCE_MODELS[model["model"]]
         has_video_input = len(reference_videos) > 0

+        if model.get("auto_downscale") and reference_videos:
+            max_px = (
+                SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {})
+                .get(model["resolution"], {})
+                .get("max")
+            )
+            if max_px:
+                for key in reference_videos:
+                    reference_videos[key] = resize_video_to_pixel_budget(
+                        reference_videos[key], max_px
+                    )
+
         total_video_duration = 0.0
         for i, key in enumerate(reference_videos, 1):
             video = reference_videos[key]
-            _validate_ref_video_pixels(video, model_id, i)
+            _validate_ref_video_pixels(video, model_id, model["resolution"], i)
             try:
                 dur = video.get_duration()
                 if dur < 1.8:
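Worked numbers for the auto_downscale branch, assuming a 1920x1080 reference clip against the 720p budget of 927_408 pixels from the table above; the even-dimension rounding mirrors _compute_downscale_dims, which this commit adds to the conversions module:

import math

# Scale factor that brings width * height under the budget while keeping aspect ratio.
src_w, src_h, max_px = 1920, 1080, 927_408
scale = math.sqrt(max_px / (src_w * src_h))          # ~0.6688
new_w = int(src_w * scale) - int(src_w * scale) % 2  # round down to even: 1284
new_h = int(src_h * scale) - int(src_h * scale) % 2  # round down to even: 722
assert (new_w, new_h) == (1284, 722)
assert new_w * new_h <= max_px                       # 927_048 <= 927_408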
@@ -24,8 +24,9 @@ from comfy_api_nodes.util import (
 AVERAGE_DURATION_VIDEO_GEN = 32
 MODELS_MAP = {
     "veo-2.0-generate-001": "veo-2.0-generate-001",
-    "veo-3.1-generate": "veo-3.1-generate-preview",
-    "veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
+    "veo-3.1-generate": "veo-3.1-generate-001",
+    "veo-3.1-fast-generate": "veo-3.1-fast-generate-001",
+    "veo-3.1-lite": "veo-3.1-lite-generate-001",
     "veo-3.0-generate-001": "veo-3.0-generate-001",
     "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
 }
@@ -247,17 +248,8 @@ class VeoVideoGenerationNode(IO.ComfyNode):
         raise Exception("Video generation completed but no video was returned")


-class Veo3VideoGenerationNode(VeoVideoGenerationNode):
-    """
-    Generates videos from text prompts using Google's Veo 3 API.
-
-    Supported models:
-    - veo-3.0-generate-001
-    - veo-3.0-fast-generate-001
-
-    This node extends the base Veo node with Veo 3 specific features including
-    audio generation and fixed 8-second duration.
-    """
+class Veo3VideoGenerationNode(IO.ComfyNode):
+    """Generates videos from text prompts using Google's Veo 3 API."""

     @classmethod
     def define_schema(cls):
@@ -279,6 +271,13 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
                     default="16:9",
                     tooltip="Aspect ratio of the output video",
                 ),
+                IO.Combo.Input(
+                    "resolution",
+                    options=["720p", "1080p", "4k"],
+                    default="720p",
+                    tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.",
+                    optional=True,
+                ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
@@ -289,11 +288,11 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
                 IO.Int.Input(
                     "duration_seconds",
                     default=8,
-                    min=8,
+                    min=4,
                     max=8,
-                    step=1,
+                    step=2,
                     display_mode=IO.NumberDisplay.number,
-                    tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
+                    tooltip="Duration of the output video in seconds",
                     optional=True,
                 ),
                 IO.Boolean.Input(
@@ -332,10 +331,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
                     options=[
                         "veo-3.1-generate",
                         "veo-3.1-fast-generate",
+                        "veo-3.1-lite",
                         "veo-3.0-generate-001",
                         "veo-3.0-fast-generate-001",
                     ],
-                    default="veo-3.0-generate-001",
                     tooltip="Veo 3 model to use for video generation",
                     optional=True,
                 ),
@@ -356,21 +355,111 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
+                depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]),
                 expr="""
                 (
                     $m := widgets.model;
+                    $r := widgets.resolution;
                     $a := widgets.generate_audio;
-                    ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
-                        ? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
-                        : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
-                            ? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
-                            : {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
+                    $seconds := widgets.duration_seconds;
+                    $pps :=
+                        $contains($m, "lite")
+                            ? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03))
+                            : $contains($m, "3.1-fast")
+                                ? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08))
+                                : $contains($m, "3.1-generate")
+                                    ? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20))
+                                    : $contains($m, "3.0-fast")
+                                        ? ($a ? 0.15 : 0.10)
+                                        : ($a ? 0.40 : 0.20);
+                    {"type":"usd","usd": $pps * $seconds}
                 )
                 """,
             ),
         )

+    @classmethod
+    async def execute(
+        cls,
+        prompt,
+        aspect_ratio="16:9",
+        resolution="720p",
+        negative_prompt="",
+        duration_seconds=8,
+        enhance_prompt=True,
+        person_generation="ALLOW",
+        seed=0,
+        image=None,
+        model="veo-3.0-generate-001",
+        generate_audio=False,
+    ):
+        if "lite" in model and resolution == "4k":
+            raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
+
+        model = MODELS_MAP[model]
+
+        instances = [{"prompt": prompt}]
+        if image is not None:
+            image_base64 = tensor_to_base64_string(image)
+            if image_base64:
+                instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
+
+        parameters = {
+            "aspectRatio": aspect_ratio,
+            "personGeneration": person_generation,
+            "durationSeconds": duration_seconds,
+            "enhancePrompt": True,
+            "generateAudio": generate_audio,
+        }
+        if negative_prompt:
+            parameters["negativePrompt"] = negative_prompt
+        if seed > 0:
+            parameters["seed"] = seed
+        if "veo-3.1" in model:
+            parameters["resolution"] = resolution
+
+        initial_response = await sync_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
+            response_model=VeoGenVidResponse,
+            data=VeoGenVidRequest(
+                instances=instances,
+                parameters=parameters,
+            ),
+        )
+
+        poll_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
+            response_model=VeoGenVidPollResponse,
+            status_extractor=lambda r: "completed" if r.done else "pending",
+            data=VeoGenVidPollRequest(operationName=initial_response.name),
+            poll_interval=9.0,
+            estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
+        )
+
+        if poll_response.error:
+            raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
+
+        response = poll_response.response
+        filtered_count = response.raiMediaFilteredCount
+        if filtered_count:
+            reasons = response.raiMediaFilteredReasons or []
+            reason_part = f": {reasons[0]}" if reasons else ""
+            raise Exception(
+                f"Content blocked by Google's Responsible AI filters{reason_part} "
+                f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
+            )
+
+        if response.videos:
+            video = response.videos[0]
+            if video.bytesBase64Encoded:
+                return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
+            if video.gcsUri:
+                return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
+            raise Exception("Video returned but no data or URL was provided")
+        raise Exception("Video generation completed but no video was returned")
+

 class Veo3FirstLastFrameNode(IO.ComfyNode):

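The rewritten badge expr replaces the old flat per-clip prices with a per-second rate ladder keyed by model family, resolution, and audio, multiplied by duration. A Python rendering of the same ladder for sanity-checking (rates copied from the expr above; the function name is mine):

def veo3_price_usd(model: str, resolution: str, audio: bool, seconds: int) -> float:
    """Mirror of the price-badge expr: per-second rate * duration."""
    if "lite" in model:
        pps = (0.08 if audio else 0.05) if resolution == "1080p" else (0.05 if audio else 0.03)
    elif "3.1-fast" in model:
        pps = {"4k": (0.30, 0.25), "1080p": (0.12, 0.10)}.get(resolution, (0.10, 0.08))[0 if audio else 1]
    elif "3.1-generate" in model:
        pps = (0.60 if audio else 0.40) if resolution == "4k" else (0.40 if audio else 0.20)
    elif "3.0-fast" in model:
        pps = 0.15 if audio else 0.10
    else:
        pps = 0.40 if audio else 0.20
    return pps * seconds

# e.g. veo-3.1-generate, 720p, with audio, 8 s -> 0.40 * 8 = 3.20 USD
assert veo3_price_usd("veo-3.1-generate", "720p", True, 8) == 3.2
assert veo3_price_usd("veo-3.1-lite", "1080p", False, 8) == 0.4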
@@ -394,7 +483,7 @@ class Veo3FirstLastFrameNode:
                     default="",
                     tooltip="Negative text prompt to guide what to avoid in the video",
                 ),
-                IO.Combo.Input("resolution", options=["720p", "1080p"]),
+                IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]),
                 IO.Combo.Input(
                     "aspect_ratio",
                     options=["16:9", "9:16"],
@@ -424,8 +513,7 @@ class Veo3FirstLastFrameNode:
                 IO.Image.Input("last_frame", tooltip="End frame"),
                 IO.Combo.Input(
                     "model",
-                    options=["veo-3.1-generate", "veo-3.1-fast-generate"],
-                    default="veo-3.1-fast-generate",
+                    options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"],
                 ),
                 IO.Boolean.Input(
                     "generate_audio",
@@ -443,26 +531,20 @@ class Veo3FirstLastFrameNode:
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
+                depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]),
                 expr="""
                 (
-                    $prices := {
-                        "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
-                        "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
-                    };
                     $m := widgets.model;
-                    $ga := (widgets.generate_audio = "true");
+                    $r := widgets.resolution;
+                    $ga := widgets.generate_audio;
                     $seconds := widgets.duration;
-                    $modelKey :=
-                        $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
-                        $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
-                        "";
-                    $audioKey := $ga ? "audio" : "no_audio";
-                    $modelPrices := $lookup($prices, $modelKey);
-                    $pps := $lookup($modelPrices, $audioKey);
-                    ($pps != null)
-                        ? {"type":"usd","usd": $pps * $seconds}
-                        : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
+                    $pps :=
+                        $contains($m, "lite")
+                            ? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03))
+                            : $contains($m, "fast")
+                                ? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08))
+                                : ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20));
+                    {"type":"usd","usd": $pps * $seconds}
                 )
                 """,
             ),
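For example, under this ladder veo-3.1-fast-generate at 1080p with audio is 0.12 USD per second, so an 8-second first/last-frame clip badges at 0.96 USD. Because every model/resolution/audio combination now resolves to a concrete rate, the old $lookup table with its null guard and range_usd fallback is no longer needed.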
@@ -482,6 +564,9 @@ class Veo3FirstLastFrameNode:
         model: str,
         generate_audio: bool,
     ):
+        if "lite" in model and resolution == "4k":
+            raise Exception("4K resolution is not supported by the veo-3.1-lite model.")
+
         model = MODELS_MAP[model]
         initial_response = await sync_op(
             cls,
@@ -519,7 +604,7 @@ class Veo3FirstLastFrameNode:
             data=VeoGenVidPollRequest(
                 operationName=initial_response.name,
             ),
-            poll_interval=5.0,
+            poll_interval=9.0,
             estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
         )

@@ -19,6 +19,7 @@ from .conversions import (
     image_tensor_pair_to_batch,
     pil_to_bytesio,
     resize_mask_to_image,
+    resize_video_to_pixel_budget,
     tensor_to_base64_string,
     tensor_to_bytesio,
     tensor_to_pil,
@@ -90,6 +91,7 @@ __all__ = [
     "image_tensor_pair_to_batch",
     "pil_to_bytesio",
     "resize_mask_to_image",
+    "resize_video_to_pixel_budget",
     "tensor_to_base64_string",
     "tensor_to_bytesio",
     "tensor_to_pil",
@@ -129,22 +129,38 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
     return img_byte_arr


+def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None:
+    """Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits.
+
+    Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions
+    are rounded down to even values (many codecs require divisible-by-2).
+    """
+    pixels = src_w * src_h
+    if pixels <= total_pixels:
+        return None
+    scale = math.sqrt(total_pixels / pixels)
+    new_w = max(2, int(src_w * scale))
+    new_h = max(2, int(src_h * scale))
+    new_w -= new_w % 2
+    new_h -= new_h % 2
+    return new_w, new_h
+
+
 def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
-    """Downscale input image tensor to roughly the specified total pixels."""
+    """Downscale input image tensor to roughly the specified total pixels.
+
+    Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels``
+    and is compatible with codecs that require even dimensions (e.g. yuv420p).
+    """
     samples = image.movedim(-1, 1)
-    total = int(total_pixels)
-    scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
-    if scale_by >= 1:
+    dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels))
+    if dims is None:
         return image
-    width = round(samples.shape[3] * scale_by)
-    height = round(samples.shape[2] * scale_by)
-
-    s = common_upscale(samples, width, height, "lanczos", "disabled")
-    s = s.movedim(1, -1)
-    return s
+    new_w, new_h = dims
+    return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1)


 def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
     """Downscale input image tensor so the largest dimension is at most max_side pixels."""
     samples = image.movedim(-1, 1)
     height, width = samples.shape[2], samples.shape[3]
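A few spot checks of _compute_downscale_dims as defined above (arguments are (width, height, budget), matching how downscale_image_tensor calls it):

assert _compute_downscale_dims(1280, 720, 1920 * 1080) is None           # already within budget
assert _compute_downscale_dims(3840, 2160, 1920 * 1080) == (1920, 1080)  # exact 2x downscale
assert _compute_downscale_dims(1920, 1080, 927_408) == (1284, 722)       # 927_048 px, even dims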
@@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
         raise RuntimeError(f"Failed to trim video: {str(e)}") from e


+def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video:
+    """Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio.
+
+    Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio.
+    Aspect ratio is preserved up to a fraction of a percent (even-dim rounding).
+    """
+    src_w, src_h = video.get_dimensions()
+    scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels)
+    if scale_dims is None:
+        return video
+    return _apply_video_scale(video, scale_dims)
+
+
+def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video:
+    """Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass."""
+    out_w, out_h = scale_dims
+    output_buffer = BytesIO()
+    input_container = None
+    output_container = None
+
+    try:
+        input_source = video.get_stream_source()
+        input_container = av.open(input_source, mode="r")
+        output_container = av.open(output_buffer, mode="w", format="mp4")
+
+        video_stream = output_container.add_stream("h264", rate=video.get_frame_rate())
+        video_stream.width = out_w
+        video_stream.height = out_h
+        video_stream.pix_fmt = "yuv420p"
+
+        audio_stream = None
+        for stream in input_container.streams:
+            if isinstance(stream, av.AudioStream):
+                audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
+                audio_stream.sample_rate = stream.sample_rate
+                audio_stream.layout = stream.layout
+                break
+
+        for frame in input_container.decode(video=0):
+            frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
+            for packet in video_stream.encode(frame):
+                output_container.mux(packet)
+        for packet in video_stream.encode():
+            output_container.mux(packet)
+
+        if audio_stream is not None:
+            input_container.seek(0)
+            for audio_frame in input_container.decode(audio=0):
+                for packet in audio_stream.encode(audio_frame):
+                    output_container.mux(packet)
+            for packet in audio_stream.encode():
+                output_container.mux(packet)
+
+        output_container.close()
+        input_container.close()
+        output_buffer.seek(0)
+        return InputImpl.VideoFromFile(output_buffer)
+
+    except Exception as e:
+        if input_container is not None:
+            input_container.close()
+        if output_container is not None:
+            output_container.close()
+        raise RuntimeError(f"Failed to resize video: {str(e)}") from e
+
+
 def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
     """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
     if wav.dtype.is_floating_point:
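A hedged usage sketch of the new helper (load_video is a stand-in for however the caller obtains an Input.Video; the 927_408 budget is the 720p Seedance limit from earlier in this commit):

# Shrink a reference clip to the 720p budget; no-op if it already fits.
resized = resize_video_to_pixel_budget(load_video("reference.mp4"), total_pixels=927_408)
w, h = resized.get_dimensions()
assert w * h <= 927_408  # original object is returned unchanged when it already fit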