Compare commits

...

5 Commits

Author SHA1 Message Date
Alexander Piskun
0662021851
Merge 28db2757e1 into c4a14df9a3 2026-01-21 11:42:56 +08:00
Mylo
c4a14df9a3
Dynamically detect chroma radiance patch size (#11991)
2026-01-20 18:46:11 -05:00
Ivan Zorin
965d0ed509
fix: remove normalization of audio in LTX Mel spectrogram creation (#11990)
For LTX Audio VAE, remove normalization of audio during Mel spectrogram creation.
This aligns inference with training and prevents loud audio from being attenuated.
2026-01-20 18:44:28 -05:00
Alexander Piskun
ddc541ffda
feat(api-nodes): add WaveSpeed nodes (#11945) 2026-01-20 13:05:40 -08:00
bigcat88
28db2757e1
fix VideoFromFile stream source to _ReentrantBytesIO for parallel async use 2025-12-21 18:26:49 +02:00
5 changed files with 361 additions and 44 deletions

View File

@@ -103,20 +103,10 @@ class AudioPreprocessor:
             return waveform
         return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
 
-    @staticmethod
-    def normalize_amplitude(
-        waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
-    ) -> torch.Tensor:
-        waveform = waveform - waveform.mean(dim=2, keepdim=True)
-        peak = torch.max(torch.abs(waveform)) + eps
-        scale = peak.clamp(max=max_amplitude) / peak
-        return waveform * scale
-
     def waveform_to_mel(
         self, waveform: torch.Tensor, waveform_sample_rate: int, device
     ) -> torch.Tensor:
         waveform = self.resample(waveform, waveform_sample_rate)
-        waveform = self.normalize_amplitude(waveform)
         mel_transform = torchaudio.transforms.MelSpectrogram(
             sample_rate=self.target_sample_rate,
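The removed normalize_amplitude centers each channel and rescales the batch so its global peak never exceeds max_amplitude, which is exactly why loud inputs were attenuated before the Mel transform. A minimal standalone sketch of that effect (torch only; shapes and names here are chosen for illustration):

import torch

def normalize_amplitude(waveform, max_amplitude=0.5, eps=1e-5):
    # Center each channel, then scale so the global peak is at most max_amplitude.
    waveform = waveform - waveform.mean(dim=2, keepdim=True)
    peak = torch.max(torch.abs(waveform)) + eps
    scale = peak.clamp(max=max_amplitude) / peak  # < 1 whenever peak > max_amplitude
    return waveform * scale

loud = torch.sin(torch.linspace(0, 100, 16000)).reshape(1, 1, -1)  # peak ~ 1.0
print(loud.abs().max().item())                       # ~ 1.0
print(normalize_amplitude(loud).abs().max().item())  # ~ 0.5: the attenuation the fix removes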

View File

@@ -253,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "chroma_radiance"
         dit_config["in_channels"] = 3
         dit_config["out_channels"] = 3
-        dit_config["patch_size"] = 16
+        dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1)
         dit_config["nerf_hidden_size"] = 64
         dit_config["nerf_mlp_ratio"] = 4
         dit_config["nerf_depth"] = 4

View File

@@ -13,6 +13,124 @@ import torch
 
 from .._util import VideoContainer, VideoCodec, VideoComponents
 
+
+class _ReentrantBytesIO(io.BytesIO):
+    """Read-only, seekable BytesIO-compatible view over shared immutable bytes."""
+
+    def __init__(self, data: bytes):
+        super().__init__(b"")  # Initialize base BytesIO with an empty buffer; we do not use its internal storage.
+        if data is None:
+            raise TypeError("data must be bytes, not None")
+        self._data = data
+        self._pos = 0
+        self._len = len(data)
+
+    def getvalue(self) -> bytes:
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        return self._data
+
+    def getbuffer(self) -> memoryview:
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        return memoryview(self._data)
+
+    def readable(self) -> bool:
+        return True
+
+    def writable(self) -> bool:
+        return False
+
+    def seekable(self) -> bool:
+        return True
+
+    def tell(self) -> int:
+        return self._pos
+
+    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        if whence == io.SEEK_SET:
+            new_pos = offset
+        elif whence == io.SEEK_CUR:
+            new_pos = self._pos + offset
+        elif whence == io.SEEK_END:
+            new_pos = self._len + offset
+        else:
+            raise ValueError(f"Invalid whence: {whence}")
+        if new_pos < 0:
+            raise ValueError("Negative seek position")
+        self._pos = new_pos
+        return self._pos
+
+    def readinto(self, b) -> int:
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        mv = memoryview(b)
+        if mv.readonly:
+            raise TypeError("readinto() argument must be writable")
+        mv = mv.cast("B")
+        if self._pos >= self._len:
+            return 0
+        n = min(len(mv), self._len - self._pos)
+        mv[:n] = self._data[self._pos:self._pos + n]
+        self._pos += n
+        return n
+
+    def readinto1(self, b) -> int:
+        return self.readinto(b)
+
+    def read(self, size: int = -1) -> bytes:
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        if size is None or size < 0:
+            size = self._len - self._pos
+        if self._pos >= self._len:
+            return b""
+        end = min(self._pos + size, self._len)
+        out = self._data[self._pos:end]
+        self._pos = end
+        return out
+
+    def read1(self, size: int = -1) -> bytes:
+        return self.read(size)
+
+    def readline(self, size: int = -1) -> bytes:
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        if self._pos >= self._len:
+            return b""
+        end_limit = self._len if size is None or size < 0 else min(self._len, self._pos + size)
+        nl = self._data.find(b"\n", self._pos, end_limit)
+        end = (nl + 1) if nl != -1 else end_limit
+        out = self._data[self._pos:end]
+        self._pos = end
+        return out
+
+    def readlines(self, hint: int = -1) -> list[bytes]:
+        if self.closed:
+            raise ValueError("I/O operation on closed file.")
+        lines: list[bytes] = []
+        total = 0
+        while True:
+            line = self.readline()
+            if not line:
+                break
+            lines.append(line)
+            total += len(line)
+            if hint is not None and 0 <= hint <= total:
+                break
+        return lines
+
+    def write(self, b) -> int:
+        raise io.UnsupportedOperation("not writable")
+
+    def writelines(self, lines) -> None:
+        raise io.UnsupportedOperation("not writable")
+
+    def truncate(self, size: int | None = None) -> int:
+        raise io.UnsupportedOperation("not writable")
+
+
 def container_to_output_format(container_format: str | None) -> str | None:
     """
     A container's `format` may be a comma-separated list of formats.
@@ -57,21 +175,31 @@ class VideoFromFile(VideoInput):
     Class representing video input from a file.
     """
 
-    def __init__(self, file: str | io.BytesIO):
+    __data: str | bytes
+
+    def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
         """
         Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
         containing the file contents.
         """
-        self.__file = file
+        if isinstance(file, str):
+            self.__data = file
+        elif isinstance(file, io.BytesIO):
+            # Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
+            self.__data = file.getbuffer().tobytes()
+        elif isinstance(file, (bytes, bytearray, memoryview)):
+            self.__data = bytes(file)
+        else:
+            raise TypeError(f"Unsupported video source type: {type(file)!r}")
 
     def get_stream_source(self) -> str | io.BytesIO:
         """
         Return the underlying file source for efficient streaming.
         This avoids unnecessary memory copies when the source is already a file path.
         """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)
-        return self.__file
+        if isinstance(self.__data, str):
+            return self.__data
+        return _ReentrantBytesIO(self.__data)
 
     def get_dimensions(self) -> tuple[int, int]:
         """
@@ -80,14 +208,12 @@ class VideoFromFile(VideoInput):
         Returns:
             Tuple of (width, height)
         """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)  # Reset the BytesIO object to the beginning
-        with av.open(self.__file, mode='r') as container:
+        with av.open(self.get_stream_source(), mode="r") as container:
             for stream in container.streams:
                 if stream.type == 'video':
                     assert isinstance(stream, av.VideoStream)
                     return stream.width, stream.height
-        raise ValueError(f"No video stream found in file '{self.__file}'")
+        raise ValueError(f"No video stream found in {self._source_label()}")
 
     def get_duration(self) -> float:
         """
@@ -96,9 +222,7 @@
         Returns:
             Duration in seconds
         """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)
-        with av.open(self.__file, mode="r") as container:
+        with av.open(self.get_stream_source(), mode="r") as container:
             if container.duration is not None:
                 return float(container.duration / av.time_base)
 
@@ -119,17 +243,14 @@
             if frame_count > 0:
                 return float(frame_count / video_stream.average_rate)
 
-        raise ValueError(f"Could not determine duration for file '{self.__file}'")
+        raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
 
     def get_frame_count(self) -> int:
         """
        Returns the number of frames in the video without materializing them as
         torch tensors.
         """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)
-        with av.open(self.__file, mode="r") as container:
+        with av.open(self.get_stream_source(), mode="r") as container:
             video_stream = self._get_first_video_stream(container)
 
             # 1. Prefer the frames field if available
             if video_stream.frames and video_stream.frames > 0:
@@ -160,7 +281,7 @@
                     frame_count += 1
 
             if frame_count == 0:
-                raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+                raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
             return frame_count
 
     def get_frame_rate(self) -> Fraction:
@@ -168,10 +289,7 @@
         Returns the average frame rate of the video using container metadata
         without decoding all frames.
         """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)
-        with av.open(self.__file, mode="r") as container:
+        with av.open(self.get_stream_source(), mode="r") as container:
             video_stream = self._get_first_video_stream(container)
 
             # Preferred: use PyAV's average_rate (usually already a Fraction-like)
             if video_stream.average_rate:
@@ -193,9 +311,7 @@
         Returns:
             Container format as string
         """
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)
-        with av.open(self.__file, mode='r') as container:
+        with av.open(self.get_stream_source(), mode='r') as container:
             return container.format.name
 
     def get_components_internal(self, container: InputContainer) -> VideoComponents:
@@ -239,11 +355,8 @@
         return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
 
     def get_components(self) -> VideoComponents:
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)  # Reset the BytesIO object to the beginning
-        with av.open(self.__file, mode='r') as container:
+        with av.open(self.get_stream_source(), mode='r') as container:
             return self.get_components_internal(container)
-        raise ValueError(f"No video stream found in file '{self.__file}'")
 
     def save_to(
         self,
@@ -252,9 +365,7 @@
         codec: VideoCodec = VideoCodec.AUTO,
         metadata: Optional[dict] = None
     ):
-        if isinstance(self.__file, io.BytesIO):
-            self.__file.seek(0)  # Reset the BytesIO object to the beginning
-        with av.open(self.__file, mode='r') as container:
+        with av.open(self.get_stream_source(), mode='r') as container:
             container_format = container.format.name
             video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
             reuse_streams = True
@@ -306,9 +417,12 @@
     def _get_first_video_stream(self, container: InputContainer):
         video_stream = next((s for s in container.streams if s.type == "video"), None)
         if video_stream is None:
-            raise ValueError(f"No video stream found in file '{self.__file}'")
+            raise ValueError(f"No video stream found in file '{self._source_label()}'")
         return video_stream
 
+    def _source_label(self) -> str:
+        return self.__data if isinstance(self.__data, str) else f"<in-memory video: {len(self.__data)} bytes>"
+
 
 class VideoFromComponents(VideoInput):
     """

View File

@@ -0,0 +1,35 @@
+from pydantic import BaseModel, Field
+
+
+class SeedVR2ImageRequest(BaseModel):
+    image: str = Field(...)
+    target_resolution: str = Field(...)
+    output_format: str = Field("png")
+    enable_sync_mode: bool = Field(False)
+
+
+class FlashVSRRequest(BaseModel):
+    target_resolution: str = Field(...)
+    video: str = Field(...)
+    duration: float = Field(...)
+
+
+class TaskCreatedDataResponse(BaseModel):
+    id: str = Field(...)
+
+
+class TaskCreatedResponse(BaseModel):
+    code: int = Field(...)
+    message: str = Field(...)
+    data: TaskCreatedDataResponse | None = Field(None)
+
+
+class TaskResultDataResponse(BaseModel):
+    status: str = Field(...)
+    outputs: list[str] = Field([])
+
+
+class TaskResultResponse(BaseModel):
+    code: int = Field(...)
+    message: str = Field(...)
+    data: TaskResultDataResponse | None = Field(None)
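These models mirror WaveSpeed's two-step flow: task creation returns an id, and polling the result endpoint returns a status plus output URLs. A hedged sketch of how raw JSON maps onto them, assuming pydantic v2's model_validate (the payloads are illustrative, not captured from the API):

created = TaskCreatedResponse.model_validate(
    {"code": 200, "message": "ok", "data": {"id": "task-123"}}
)
print(created.data.id)  # task-123

result = TaskResultResponse.model_validate(
    {"code": 200, "message": "ok",
     "data": {"status": "completed", "outputs": ["https://example.com/out.mp4"]}}
)
print(result.data.status, result.data.outputs[0])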

View File

@@ -0,0 +1,178 @@
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.wavespeed import (
+    FlashVSRRequest,
+    TaskCreatedResponse,
+    TaskResultResponse,
+    SeedVR2ImageRequest,
+)
+from comfy_api_nodes.util import (
+    ApiEndpoint,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
+    upload_video_to_comfyapi,
+    validate_container_format_is_mp4,
+    validate_video_duration,
+    upload_images_to_comfyapi,
+    get_number_of_images,
+    download_url_to_image_tensor,
+)
+
+
+class WavespeedFlashVSRNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="WavespeedFlashVSRNode",
+            display_name="FlashVSR Video Upscale",
+            category="api node/video/WaveSpeed",
+            description="Fast, high-quality video upscaler that "
+            "boosts resolution and restores clarity for low-resolution or blurry footage.",
+            inputs=[
+                IO.Video.Input("video"),
+                IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
+                expr="""
+                (
+                    $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032};
+                    {
+                        "type":"usd",
+                        "usd": $lookup($price_for_1sec, widgets.target_resolution),
+                        "format":{"suffix": "/second", "approximate": true}
+                    }
+                )
+                """,
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        video: Input.Video,
+        target_resolution: str,
+    ) -> IO.NodeOutput:
+        validate_container_format_is_mp4(video)
+        validate_video_duration(video, min_duration=5, max_duration=60 * 10)
+        initial_res = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
+            response_model=TaskCreatedResponse,
+            data=FlashVSRRequest(
+                target_resolution=target_resolution.lower(),
+                video=await upload_video_to_comfyapi(cls, video),
+                duration=video.get_duration(),
+            ),
+        )
+        if initial_res.code != 200:
+            raise ValueError(f"Task creation failed with code={initial_res.code} and message={initial_res.message}")
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
+            response_model=TaskResultResponse,
+            status_extractor=lambda x: "failed" if x.data is None else x.data.status,
+            poll_interval=10.0,
+            max_poll_attempts=480,
+        )
+        if final_response.code != 200:
+            raise ValueError(
+                f"Task processing failed with code={final_response.code} and message={final_response.message}"
+            )
+        return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0]))
+
+
+class WavespeedImageUpscaleNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="WavespeedImageUpscaleNode",
+            display_name="WaveSpeed Image Upscale",
+            category="api node/image/WaveSpeed",
+            description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
+            inputs=[
+                IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
+                IO.Image.Input("image"),
+                IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
+            ],
+            outputs=[
+                IO.Image.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
+                expr="""
+                (
+                    $prices := {"seedvr2": 0.01, "ultimate": 0.06};
+                    {"type":"usd", "usd": $lookup($prices, widgets.model)}
+                )
+                """,
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model: str,
+        image: Input.Image,
+        target_resolution: str,
+    ) -> IO.NodeOutput:
+        if get_number_of_images(image) != 1:
+            raise ValueError("Exactly one input image is required.")
+        if model == "SeedVR2":
+            model_path = "seedvr2/image"
+        else:
+            model_path = "ultimate-image-upscaler"
+        initial_res = await sync_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
+            response_model=TaskCreatedResponse,
+            data=SeedVR2ImageRequest(
+                target_resolution=target_resolution.lower(),
+                image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
+            ),
+        )
+        if initial_res.code != 200:
+            raise ValueError(f"Task creation failed with code={initial_res.code} and message={initial_res.message}")
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
+            response_model=TaskResultResponse,
+            status_extractor=lambda x: "failed" if x.data is None else x.data.status,
+            poll_interval=10.0,
+            max_poll_attempts=480,
+        )
+        if final_response.code != 200:
+            raise ValueError(
+                f"Task processing failed with code={final_response.code} and message={final_response.message}"
+            )
+        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
+
+
+class WavespeedExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            WavespeedFlashVSRNode,
+            WavespeedImageUpscaleNode,
+        ]
+
+
+async def comfy_entrypoint() -> WavespeedExtension:
+    return WavespeedExtension()
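Both nodes share the same submit-then-poll shape: sync_op creates the task, then poll_op re-queries /proxy/wavespeed/api/v3/predictions/{id}/result until the extracted status is terminal. A framework-free sketch of that loop (fetch_status is a hypothetical stand-in for the HTTP call; the interval and attempt budget mirror the node settings):

import asyncio

async def poll_until_done(fetch_status, poll_interval: float = 10.0, max_poll_attempts: int = 480):
    # Re-query the task until it reaches a terminal state or the attempt budget runs out.
    for _ in range(max_poll_attempts):
        result = await fetch_status()
        status = "failed" if result.get("data") is None else result["data"]["status"]
        if status == "completed":
            return result["data"]["outputs"]
        if status == "failed":
            raise ValueError(f"Task processing failed with message={result.get('message')}")
        await asyncio.sleep(poll_interval)
    raise TimeoutError("Task did not finish within the polling budget")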