mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 11:32:31 +08:00
Compare commits
15 Commits
3d42682daa
...
42aad15a9b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42aad15a9b | ||
|
|
ba5bf3f1a8 | ||
|
|
c05a08ae66 | ||
|
|
de9ada6a41 | ||
|
|
37f711d4a1 | ||
|
|
f387379873 | ||
|
|
f650f91697 | ||
|
|
0864bcec00 | ||
|
|
fa8241f85e | ||
|
|
37c2a960cb | ||
|
|
4b37647ce1 | ||
|
|
ae7bf48331 | ||
|
|
7d2c369f45 | ||
|
|
e1558efbea | ||
|
|
803808b1b1 |
@ -19,7 +19,8 @@
|
||||
import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
import threading
|
||||
import torch
|
||||
import sys
|
||||
import platform
|
||||
@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
@ -746,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
return
|
||||
|
||||
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
|
||||
with torch.inference_mode():
|
||||
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||
soft_empty_cache()
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
#Deliberately load models outside of the Aimdo mempool so they can be retained accross
|
||||
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
|
||||
#thread local. So exploit that to escape context
|
||||
if enables_dynamic_vram():
|
||||
t = threading.Thread(
|
||||
target=load_models_gpu_thread,
|
||||
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||
)
|
||||
t.start()
|
||||
t.join()
|
||||
else:
|
||||
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
|
||||
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
|
||||
def load_model_gpu(model):
|
||||
return load_models_gpu([model])
|
||||
|
||||
@ -1112,11 +1133,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
return None
|
||||
if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
|
||||
#I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
del STREAM_CAST_BUFFERS[offload_stream]
|
||||
del cast_buffer
|
||||
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
|
||||
torch.cuda.empty_cache()
|
||||
soft_empty_cache()
|
||||
with wf_context:
|
||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||
STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
@ -1132,9 +1153,7 @@ def reset_cast_buffers():
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
|
||||
torch.cuda.empty_cache()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
@ -1284,7 +1303,7 @@ def discard_cuda_async_error():
|
||||
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
except torch.AcceleratorError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
@ -1688,6 +1707,12 @@ def lora_compute_dtype(device):
|
||||
LORA_COMPUTE_DTYPES[device] = dtype
|
||||
return dtype
|
||||
|
||||
def synchronize():
|
||||
if is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
@ -1713,9 +1738,6 @@ def debug_memory_summary():
|
||||
return torch.cuda.memory.memory_summary()
|
||||
return ""
|
||||
|
||||
#TODO: might be cleaner to put this somewhere else
|
||||
import threading
|
||||
|
||||
class InterruptProcessingException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
if unpatch_weights:
|
||||
self.partially_unload_ram(1e32)
|
||||
self.partially_unload(None)
|
||||
self.partially_unload(None, 1e32)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
|
||||
@ -34,6 +34,21 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = False,
|
||||
) -> VideoInput | None:
|
||||
"""
|
||||
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
||||
|
||||
Returns:
|
||||
A new VideoInput, or None if the result would have negative duration
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Optional
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
self.__start_time = start_time
|
||||
self.__duration = duration
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
raw_duration = self._get_raw_duration()
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
if self.__duration:
|
||||
return min(self.__duration, duration_from_start)
|
||||
return duration_from_start
|
||||
|
||||
def _get_raw_duration(self) -> float:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for packet in frame_iterator:
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
# 1. Prefer the frames field if available and usable
|
||||
if (
|
||||
video_stream.frames
|
||||
and video_stream.frames > 0
|
||||
and not self.__start_time
|
||||
and not self.__duration
|
||||
):
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
raw_duration = float(video_stream.duration * video_stream.time_base)
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
duration_seconds = min(self.__duration, duration_from_start)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
frame_count = 1
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= start_pts:
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= end_pts:
|
||||
break
|
||||
frame_count += 1
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
for frame in container.decode(video_stream):
|
||||
if frame.pts < start_pts:
|
||||
continue
|
||||
if self.__duration and frame.pts >= end_pts:
|
||||
break
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
# Use last stream for consistency
|
||||
if len(container.streams.audio):
|
||||
audio_stream = container.streams.audio[-1]
|
||||
audio_frames = []
|
||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
||||
frames = itertools.chain.from_iterable(
|
||||
map(resample, container.decode(audio_stream))
|
||||
)
|
||||
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
if has_first_frame:
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
if self.__duration:
|
||||
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
||||
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
||||
})
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
if len(container.streams.video):
|
||||
return container.streams.video[0]
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def as_trimmed(
|
||||
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
|
||||
) -> VideoInput | None:
|
||||
trimmed = VideoFromFile(
|
||||
self.get_stream_source(),
|
||||
start_time=start_time + self.__start_time,
|
||||
duration=duration + self.__duration,
|
||||
)
|
||||
if trimmed.get_duration() < duration and strict_duration:
|
||||
return None
|
||||
return trimmed
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def save_to(
|
||||
@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
@ -381,3 +450,14 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = True,
|
||||
) -> VideoInput | None:
|
||||
if self.get_duration() < start_time + duration:
|
||||
return None
|
||||
#TODO Consider tracking duration and trimming at time of save?
|
||||
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
|
||||
|
||||
51
comfy_api_nodes/apis/hitpaw.py
Normal file
51
comfy_api_nodes/apis/hitpaw.py
Normal file
@ -0,0 +1,51 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class InputVideoModel(TypedDict):
|
||||
model: str
|
||||
resolution: str
|
||||
|
||||
|
||||
class ImageEnhanceTaskCreateRequest(BaseModel):
|
||||
model_name: str = Field(...)
|
||||
img_url: str = Field(...)
|
||||
extension: str = Field(".png")
|
||||
exif: bool = Field(False)
|
||||
DPI: int | None = Field(None)
|
||||
|
||||
|
||||
class VideoEnhanceTaskCreateRequest(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
extension: str = Field(".mp4")
|
||||
model_name: str | None = Field(...)
|
||||
resolution: list[int] = Field(..., description="Target resolution [width, height]")
|
||||
original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")
|
||||
|
||||
|
||||
class TaskCreateDataResponse(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
consume_coins: int | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusPollRequest(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
|
||||
|
||||
class TaskCreateResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskCreateDataResponse | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusDataResponse(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
res_url: str = Field("")
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
code: int = Field(...)
|
||||
message: str = Field(...)
|
||||
data: TaskStatusDataResponse = Field(...)
|
||||
342
comfy_api_nodes/nodes_hitpaw.py
Normal file
342
comfy_api_nodes/nodes_hitpaw.py
Normal file
@ -0,0 +1,342 @@
|
||||
import math
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.hitpaw import (
|
||||
ImageEnhanceTaskCreateRequest,
|
||||
InputVideoModel,
|
||||
TaskCreateDataResponse,
|
||||
TaskCreateResponse,
|
||||
TaskStatusPollRequest,
|
||||
TaskStatusResponse,
|
||||
VideoEnhanceTaskCreateRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor,
|
||||
get_image_dimensions,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
VIDEO_MODELS_MODELS_MAP = {
|
||||
"Portrait Restore Model (1x)": "portrait_restore_1x",
|
||||
"Portrait Restore Model (2x)": "portrait_restore_2x",
|
||||
"General Restore Model (1x)": "general_restore_1x",
|
||||
"General Restore Model (2x)": "general_restore_2x",
|
||||
"General Restore Model (4x)": "general_restore_4x",
|
||||
"Ultra HD Model (2x)": "ultrahd_restore_2x",
|
||||
"Generative Model (1x)": "generative_1x",
|
||||
}
|
||||
|
||||
# Resolution name to target dimension (shorter side) in pixels
|
||||
RESOLUTION_TARGET_MAP = {
|
||||
"720p": 720,
|
||||
"1080p": 1080,
|
||||
"2K/QHD": 1440,
|
||||
"4K/UHD": 2160,
|
||||
"8K": 4320,
|
||||
}
|
||||
|
||||
# Square (1:1) resolutions use standard square dimensions
|
||||
RESOLUTION_SQUARE_MAP = {
|
||||
"720p": 720,
|
||||
"1080p": 1080,
|
||||
"2K/QHD": 1440,
|
||||
"4K/UHD": 2048, # DCI 4K square
|
||||
"8K": 4096, # DCI 8K square
|
||||
}
|
||||
|
||||
# Models with limited resolution support (no 8K)
|
||||
LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"}
|
||||
|
||||
# Resolution options for different model types
|
||||
RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"]
|
||||
RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"]
|
||||
|
||||
# Maximum output resolution in pixels
|
||||
MAX_PIXELS_GENERATIVE = 32_000_000
|
||||
MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000
|
||||
|
||||
|
||||
class HitPawGeneralImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="HitPawGeneralImageEnhance",
|
||||
display_name="HitPaw General Image Enhance",
|
||||
category="api node/image/HitPaw",
|
||||
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
|
||||
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["generative_portrait", "generative"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("upscale_factor", options=[1, 2, 4]),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
tooltip="Automatically downscale input image if output would exceed the limit.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {
|
||||
"generative_portrait": {"min": 0.02, "max": 0.06},
|
||||
"generative": {"min": 0.05, "max": 0.15}
|
||||
};
|
||||
$price := $lookup($prices, widgets.model);
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $price.min,
|
||||
"max_usd": $price.max
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: Input.Image,
|
||||
upscale_factor: int,
|
||||
auto_downscale: bool,
|
||||
) -> IO.NodeOutput:
|
||||
height, width = get_image_dimensions(image)
|
||||
requested_scale = upscale_factor
|
||||
output_pixels = height * width * requested_scale * requested_scale
|
||||
if output_pixels > MAX_PIXELS_GENERATIVE:
|
||||
if auto_downscale:
|
||||
input_pixels = width * height
|
||||
scale = 1
|
||||
max_input_pixels = MAX_PIXELS_GENERATIVE
|
||||
|
||||
for candidate in [4, 2, 1]:
|
||||
if candidate > requested_scale:
|
||||
continue
|
||||
scale_output_pixels = input_pixels * candidate * candidate
|
||||
if scale_output_pixels <= MAX_PIXELS_GENERATIVE:
|
||||
scale = candidate
|
||||
max_input_pixels = None
|
||||
break
|
||||
# Check if we can downscale input by at most 2x to fit
|
||||
downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE)
|
||||
if downscale_ratio <= 2.0:
|
||||
scale = candidate
|
||||
max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate)
|
||||
break
|
||||
|
||||
if max_input_pixels is not None:
|
||||
image = downscale_image_tensor(image, total_pixels=max_input_pixels)
|
||||
upscale_factor = scale
|
||||
else:
|
||||
output_width = width * requested_scale
|
||||
output_height = height * requested_scale
|
||||
raise ValueError(
|
||||
f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) "
|
||||
f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). "
|
||||
f"Enable auto_downscale or use a smaller input image or a lower upscale factor."
|
||||
)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"),
|
||||
response_model=TaskCreateResponse,
|
||||
data=ImageEnhanceTaskCreateRequest(
|
||||
model_name=f"{model}_{upscale_factor}x",
|
||||
img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
|
||||
request_price = initial_res.data.consume_coins / 1000
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
|
||||
data=TaskCreateDataResponse(job_id=initial_res.data.job_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda x: x.data.status,
|
||||
price_extractor=lambda x: request_price,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
|
||||
|
||||
|
||||
class HitPawVideoEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
model_options = []
|
||||
for model_name in VIDEO_MODELS_MODELS_MAP:
|
||||
if model_name in LIMITED_RESOLUTION_MODELS:
|
||||
resolutions = RESOLUTIONS_LIMITED
|
||||
else:
|
||||
resolutions = RESOLUTIONS_FULL
|
||||
model_options.append(
|
||||
IO.DynamicCombo.Option(
|
||||
model_name,
|
||||
[IO.Combo.Input("resolution", options=resolutions)],
|
||||
)
|
||||
)
|
||||
|
||||
return IO.Schema(
|
||||
node_id="HitPawVideoEnhance",
|
||||
display_name="HitPaw Video Enhance",
|
||||
category="api node/video/HitPaw",
|
||||
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
|
||||
"Prices shown are per second of video.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input("model", options=model_options),
|
||||
IO.Video.Input("video"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$m := $lookup(widgets, "model");
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$standard_model_prices := {
|
||||
"original": {"min": 0.01, "max": 0.198},
|
||||
"720p": {"min": 0.01, "max": 0.06},
|
||||
"1080p": {"min": 0.015, "max": 0.09},
|
||||
"2k/qhd": {"min": 0.02, "max": 0.117},
|
||||
"4k/uhd": {"min": 0.025, "max": 0.152},
|
||||
"8k": {"min": 0.033, "max": 0.198}
|
||||
};
|
||||
$ultra_hd_model_prices := {
|
||||
"original": {"min": 0.015, "max": 0.264},
|
||||
"720p": {"min": 0.015, "max": 0.092},
|
||||
"1080p": {"min": 0.02, "max": 0.12},
|
||||
"2k/qhd": {"min": 0.026, "max": 0.156},
|
||||
"4k/uhd": {"min": 0.034, "max": 0.203},
|
||||
"8k": {"min": 0.044, "max": 0.264}
|
||||
};
|
||||
$generative_model_prices := {
|
||||
"original": {"min": 0.015, "max": 0.338},
|
||||
"720p": {"min": 0.008, "max": 0.090},
|
||||
"1080p": {"min": 0.05, "max": 0.15},
|
||||
"2k/qhd": {"min": 0.038, "max": 0.225},
|
||||
"4k/uhd": {"min": 0.056, "max": 0.338}
|
||||
};
|
||||
$prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices :
|
||||
$contains($m, "generative") ? $generative_model_prices :
|
||||
$standard_model_prices;
|
||||
$price := $lookup($prices, $res);
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $price.min,
|
||||
"max_usd": $price.max,
|
||||
"format": {"approximate": true, "suffix": "/second"}
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: InputVideoModel,
|
||||
video: Input.Video,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, min_duration=0.5, max_duration=60 * 60)
|
||||
resolution = model["resolution"]
|
||||
src_width, src_height = video.get_dimensions()
|
||||
|
||||
if resolution == "original":
|
||||
output_width = src_width
|
||||
output_height = src_height
|
||||
else:
|
||||
if src_width == src_height:
|
||||
target_size = RESOLUTION_SQUARE_MAP[resolution]
|
||||
if target_size < src_width:
|
||||
raise ValueError(
|
||||
f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than "
|
||||
f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'."
|
||||
)
|
||||
output_width = target_size
|
||||
output_height = target_size
|
||||
else:
|
||||
min_dimension = min(src_width, src_height)
|
||||
target_size = RESOLUTION_TARGET_MAP[resolution]
|
||||
if target_size < min_dimension:
|
||||
raise ValueError(
|
||||
f"Selected resolution {resolution} ({target_size}p) is smaller than "
|
||||
f"the input video's shorter dimension ({min_dimension}p). "
|
||||
f"Please select a higher resolution or 'original'."
|
||||
)
|
||||
if src_width > src_height:
|
||||
output_height = target_size
|
||||
output_width = int(target_size * (src_width / src_height))
|
||||
else:
|
||||
output_width = target_size
|
||||
output_height = int(target_size * (src_height / src_width))
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"),
|
||||
response_model=TaskCreateResponse,
|
||||
data=VideoEnhanceTaskCreateRequest(
|
||||
video_url=await upload_video_to_comfyapi(cls, video),
|
||||
resolution=[output_width, output_height],
|
||||
original_resolution=[src_width, src_height],
|
||||
model_name=VIDEO_MODELS_MODELS_MAP[model["model"]],
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
request_price = initial_res.data.consume_coins / 1000
|
||||
if initial_res.code != 200:
|
||||
raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
|
||||
data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda x: x.data.status,
|
||||
price_extractor=lambda x: request_price,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))
|
||||
|
||||
|
||||
class HitPawExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
HitPawGeneralImageEnhance,
|
||||
HitPawVideoEnhance,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> HitPawExtension:
|
||||
return HitPawExtension()
|
||||
@ -94,7 +94,7 @@ async def upload_image_to_comfyapi(
|
||||
*,
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
total_pixels: int = 2048 * 2048,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
) -> str:
|
||||
"""Uploads a single image to ComfyUI API and returns its download URL."""
|
||||
return (
|
||||
|
||||
@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
|
||||
|
||||
return True
|
||||
|
||||
class VideoSlice(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Video Slice",
|
||||
display_name="Video Slice",
|
||||
search_aliases=[
|
||||
"trim video duration",
|
||||
"skip first frames",
|
||||
"frame load cap",
|
||||
"start time",
|
||||
],
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Video.Input("video"),
|
||||
io.Float.Input(
|
||||
"start_time",
|
||||
default=0.0,
|
||||
max=1e5,
|
||||
min=-1e5,
|
||||
step=0.001,
|
||||
tooltip="Start time in seconds",
|
||||
),
|
||||
io.Float.Input(
|
||||
"duration",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
step=0.001,
|
||||
tooltip="Duration in seconds, or 0 for unlimited duration",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"strict_duration",
|
||||
default=False,
|
||||
tooltip="If True, when the specified duration is not possible, an error will be raised.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
|
||||
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
|
||||
if trimmed is not None:
|
||||
return io.NodeOutput(trimmed)
|
||||
raise ValueError(
|
||||
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
|
||||
)
|
||||
|
||||
|
||||
class VideoExtension(ComfyExtension):
|
||||
@override
|
||||
@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
VideoSlice,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user