mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 08:40:19 +08:00
Compare commits
5 Commits
aeced1a430
...
0662021851
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0662021851 | ||
|
|
c4a14df9a3 | ||
|
|
965d0ed509 | ||
|
|
ddc541ffda | ||
|
|
28db2757e1 |
@ -103,20 +103,10 @@ class AudioPreprocessor:
|
||||
return waveform
|
||||
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
|
||||
|
||||
@staticmethod
|
||||
def normalize_amplitude(
|
||||
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
|
||||
) -> torch.Tensor:
|
||||
waveform = waveform - waveform.mean(dim=2, keepdim=True)
|
||||
peak = torch.max(torch.abs(waveform)) + eps
|
||||
scale = peak.clamp(max=max_amplitude) / peak
|
||||
return waveform * scale
|
||||
|
||||
def waveform_to_mel(
|
||||
self, waveform: torch.Tensor, waveform_sample_rate: int, device
|
||||
) -> torch.Tensor:
|
||||
waveform = self.resample(waveform, waveform_sample_rate)
|
||||
waveform = self.normalize_amplitude(waveform)
|
||||
|
||||
mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=self.target_sample_rate,
|
||||
|
||||
@ -253,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
dit_config["patch_size"] = 16
|
||||
dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1)
|
||||
dit_config["nerf_hidden_size"] = 64
|
||||
dit_config["nerf_mlp_ratio"] = 4
|
||||
dit_config["nerf_depth"] = 4
|
||||
|
||||
@ -13,6 +13,124 @@ import torch
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
class _ReentrantBytesIO(io.BytesIO):
|
||||
"""Read-only, seekable BytesIO-compatible view over shared immutable bytes."""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
super().__init__(b"") # Initialize base BytesIO with an empty buffer; we do not use its internal storage.
|
||||
if data is None:
|
||||
raise TypeError("data must be bytes, not None")
|
||||
self._data = data
|
||||
self._pos = 0
|
||||
self._len = len(data)
|
||||
|
||||
def getvalue(self) -> bytes:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
return self._data
|
||||
|
||||
def getbuffer(self) -> memoryview:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
return memoryview(self._data)
|
||||
|
||||
def readable(self) -> bool:
|
||||
return True
|
||||
|
||||
def writable(self) -> bool:
|
||||
return False
|
||||
|
||||
def seekable(self) -> bool:
|
||||
return True
|
||||
|
||||
def tell(self) -> int:
|
||||
return self._pos
|
||||
|
||||
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
if whence == io.SEEK_SET:
|
||||
new_pos = offset
|
||||
elif whence == io.SEEK_CUR:
|
||||
new_pos = self._pos + offset
|
||||
elif whence == io.SEEK_END:
|
||||
new_pos = self._len + offset
|
||||
else:
|
||||
raise ValueError(f"Invalid whence: {whence}")
|
||||
if new_pos < 0:
|
||||
raise ValueError("Negative seek position")
|
||||
self._pos = new_pos
|
||||
return self._pos
|
||||
|
||||
def readinto(self, b) -> int:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
mv = memoryview(b)
|
||||
if mv.readonly:
|
||||
raise TypeError("readinto() argument must be writable")
|
||||
mv = mv.cast("B")
|
||||
if self._pos >= self._len:
|
||||
return 0
|
||||
n = min(len(mv), self._len - self._pos)
|
||||
mv[:n] = self._data[self._pos:self._pos + n]
|
||||
self._pos += n
|
||||
return n
|
||||
|
||||
def readinto1(self, b) -> int:
|
||||
return self.readinto(b)
|
||||
|
||||
def read(self, size: int = -1) -> bytes:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
if size is None or size < 0:
|
||||
size = self._len - self._pos
|
||||
if self._pos >= self._len:
|
||||
return b""
|
||||
end = min(self._pos + size, self._len)
|
||||
out = self._data[self._pos:end]
|
||||
self._pos = end
|
||||
return out
|
||||
|
||||
def read1(self, size: int = -1) -> bytes:
|
||||
return self.read(size)
|
||||
|
||||
def readline(self, size: int = -1) -> bytes:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
if self._pos >= self._len:
|
||||
return b""
|
||||
end_limit = self._len if size is None or size < 0 else min(self._len, self._pos + size)
|
||||
nl = self._data.find(b"\n", self._pos, end_limit)
|
||||
end = (nl + 1) if nl != -1 else end_limit
|
||||
out = self._data[self._pos:end]
|
||||
self._pos = end
|
||||
return out
|
||||
|
||||
def readlines(self, hint: int = -1) -> list[bytes]:
|
||||
if self.closed:
|
||||
raise ValueError("I/O operation on closed file.")
|
||||
lines: list[bytes] = []
|
||||
total = 0
|
||||
while True:
|
||||
line = self.readline()
|
||||
if not line:
|
||||
break
|
||||
lines.append(line)
|
||||
total += len(line)
|
||||
if hint is not None and 0 <= hint <= total:
|
||||
break
|
||||
return lines
|
||||
|
||||
def write(self, b) -> int:
|
||||
raise io.UnsupportedOperation("not writable")
|
||||
|
||||
def writelines(self, lines) -> None:
|
||||
raise io.UnsupportedOperation("not writable")
|
||||
|
||||
def truncate(self, size: int | None = None) -> int:
|
||||
raise io.UnsupportedOperation("not writable")
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
"""
|
||||
A container's `format` may be a comma-separated list of formats.
|
||||
@ -57,21 +175,31 @@ class VideoFromFile(VideoInput):
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
__data: str | bytes
|
||||
|
||||
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
if isinstance(file, str):
|
||||
self.__data = file
|
||||
elif isinstance(file, io.BytesIO):
|
||||
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
|
||||
self.__data = file.getbuffer().tobytes()
|
||||
elif isinstance(file, (bytes, bytearray, memoryview)):
|
||||
self.__data = bytes(file)
|
||||
else:
|
||||
raise TypeError(f"Unsupported video source type: {type(file)!r}")
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
Return the underlying file source for efficient streaming.
|
||||
This avoids unnecessary memory copies when the source is already a file path.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
if isinstance(self.__data, str):
|
||||
return self.__data
|
||||
return _ReentrantBytesIO(self.__data)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
@ -80,14 +208,12 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
for stream in container.streams:
|
||||
if stream.type == 'video':
|
||||
assert isinstance(stream, av.VideoStream)
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
raise ValueError(f"No video stream found in {self._source_label()}")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
@ -96,9 +222,7 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
@ -119,17 +243,14 @@ class VideoFromFile(VideoInput):
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video without materializing them as
|
||||
torch tensors.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
@ -160,7 +281,7 @@ class VideoFromFile(VideoInput):
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -168,10 +289,7 @@ class VideoFromFile(VideoInput):
|
||||
Returns the average frame rate of the video using container metadata
|
||||
without decoding all frames.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
with av.open(self.get_stream_source(), mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
|
||||
if video_stream.average_rate:
|
||||
@ -193,9 +311,7 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Container format as string
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode='r') as container:
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
@ -239,11 +355,8 @@ class VideoFromFile(VideoInput):
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode='r') as container:
|
||||
return self.get_components_internal(container)
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
@ -252,9 +365,7 @@ class VideoFromFile(VideoInput):
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
with av.open(self.get_stream_source(), mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
reuse_streams = True
|
||||
@ -306,9 +417,12 @@ class VideoFromFile(VideoInput):
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
raise ValueError(f"No video stream found in file '{self._source_label()}'")
|
||||
return video_stream
|
||||
|
||||
def _source_label(self) -> str:
|
||||
return self.__data if isinstance(self.__data, str) else f"<in-memory video: {len(self.__data)} bytes>"
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
|
||||
35
comfy_api_nodes/apis/wavespeed.py
Normal file
35
comfy_api_nodes/apis/wavespeed.py
Normal file
@ -0,0 +1,35 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SeedVR2ImageRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
target_resolution: str = Field(...)
|
||||
output_format: str = Field("png")
|
||||
enable_sync_mode: bool = Field(False)
|
||||
|
||||
|
||||
class FlashVSRRequest(BaseModel):
|
||||
target_resolution: str = Field(...)
|
||||
video: str = Field(...)
|
||||
duration: float = Field(...)
|
||||
|
||||
|
||||
class TaskCreatedDataResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskCreatedResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskCreatedDataResponse | None = Field(None)
|
||||
|
||||
|
||||
class TaskResultDataResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
outputs: list[str] = Field([])
|
||||
|
||||
|
||||
class TaskResultResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskResultDataResponse | None = Field(None)
|
||||
178
comfy_api_nodes/nodes_wavespeed.py
Normal file
178
comfy_api_nodes/nodes_wavespeed.py
Normal file
@ -0,0 +1,178 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.wavespeed import (
|
||||
FlashVSRRequest,
|
||||
TaskCreatedResponse,
|
||||
TaskResultResponse,
|
||||
SeedVR2ImageRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_video_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
validate_video_duration,
|
||||
upload_images_to_comfyapi,
|
||||
get_number_of_images,
|
||||
download_url_to_image_tensor,
|
||||
)
|
||||
|
||||
|
||||
class WavespeedFlashVSRNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedFlashVSRNode",
|
||||
display_name="FlashVSR Video Upscale",
|
||||
category="api node/video/WaveSpeed",
|
||||
description="Fast, high-quality video upscaler that "
|
||||
"boosts resolution and restores clarity for low-resolution or blurry footage.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032};
|
||||
{
|
||||
"type":"usd",
|
||||
"usd": $lookup($price_for_1sec, widgets.target_resolution),
|
||||
"format":{"suffix": "/second", "approximate": true}
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
target_resolution: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_container_format_is_mp4(video)
|
||||
validate_video_duration(video, min_duration=5, max_duration=60 * 10)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
|
||||
response_model=TaskCreatedResponse,
|
||||
data=FlashVSRRequest(
|
||||
target_resolution=target_resolution.lower(),
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
duration=video.get_duration(),
|
||||
),
|
||||
)
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
|
||||
response_model=TaskResultResponse,
|
||||
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
if final_response.code != 200:
|
||||
raise ValueError(
|
||||
f"Task processing failed with code={final_response.code} and message={final_response.message}"
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0]))
|
||||
|
||||
|
||||
class WavespeedImageUpscaleNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedImageUpscaleNode",
|
||||
display_name="WaveSpeed Image Upscale",
|
||||
category="api node/image/WaveSpeed",
|
||||
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"seedvr2": 0.01, "ultimate": 0.06};
|
||||
{"type":"usd", "usd": $lookup($prices, widgets.model)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
target_resolution: str,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
if model == "SeedVR2":
|
||||
model_path = "seedvr2/image"
|
||||
else:
|
||||
model_path = "ultimate-image-upscaler"
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"),
|
||||
response_model=TaskCreatedResponse,
|
||||
data=SeedVR2ImageRequest(
|
||||
target_resolution=target_resolution.lower(),
|
||||
image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
|
||||
),
|
||||
)
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}")
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
|
||||
response_model=TaskResultResponse,
|
||||
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
if final_response.code != 200:
|
||||
raise ValueError(
|
||||
f"Task processing failed with code={final_response.code} and message={final_response.message}"
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0]))
|
||||
|
||||
|
||||
class WavespeedExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
WavespeedFlashVSRNode,
|
||||
WavespeedImageUpscaleNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> WavespeedExtension:
|
||||
return WavespeedExtension()
|
||||
Loading…
Reference in New Issue
Block a user