Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 16:32:34 +08:00)

Compare commits: 01cd571aee...84ae070c88 (1 commit: 84ae070c88)
@@ -19,8 +19,7 @@
 import psutil
 import logging
 from enum import Enum
-from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
-import threading
+from comfy.cli_args import args, PerformanceFeature
 import torch
 import sys
 import platform
@@ -651,7 +650,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
     soft_empty_cache()
     return unloaded_models

-def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state

@@ -747,26 +746,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
         current_loaded_models.insert(0, loaded_model)
     return

-def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
-    with torch.inference_mode():
-        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
-        soft_empty_cache()
-
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
-    #Deliberately load models outside of the Aimdo mempool so they can be retained accross
-    #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
-    #thread local. So exploit that to escape context
-    if enables_dynamic_vram():
-        t = threading.Thread(
-            target=load_models_gpu_thread,
-            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
-        )
-        t.start()
-        t.join()
-    else:
-        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
-                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
-
 def load_model_gpu(model):
     return load_models_gpu([model])

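The deleted wrapper leaned on a documented PyTorch property: `torch.cuda.MemPool` contexts are thread-local, so work dispatched to a fresh thread allocates from the default caching allocator even when the caller sits inside a mempool context. A minimal sketch of that escape pattern, with a hypothetical helper name (not a ComfyUI API):

```python
import threading

def run_outside_mempool(fn, *args, **kwargs):
    # Mempool contexts are thread-local, so a fresh thread escapes the
    # caller's pool and allocates from the default CUDA caching allocator.
    result = {}

    def worker():
        result["value"] = fn(*args, **kwargs)

    t = threading.Thread(target=worker)
    t.start()
    t.join()  # block until done, mirroring the removed start()/join() pair
    return result["value"]
```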
@@ -1133,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref):
         return None
     if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
         #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
-        synchronize()
+        torch.cuda.synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
         #FIXME: This doesn't work in Aimdo because mempool cant clear cache
-        soft_empty_cache()
+        torch.cuda.empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
     STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
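The context lines around this hunk show the reclamation pattern being reverted to plain CUDA calls: an oversized staging buffer is dropped only after a device-wide synchronize (an offload stream may still be reading it), and the cache is emptied so the allocator actually returns the block. A standalone sketch of that pattern, assuming a CUDA device (helper name and registry are hypothetical):

```python
import torch

CAST_BUFFERS: dict = {}  # stand-in for STREAM_CAST_BUFFERS

def get_fresh_cast_buffer(key, size: int, device) -> torch.Tensor:
    buf = CAST_BUFFERS.get(key)
    if buf is not None and buf.numel() > 50 * (1024 ** 2):  # > 50 MiB
        torch.cuda.synchronize()  # ensure no stream still uses the buffer
        del CAST_BUFFERS[key]
        del buf                   # drop the last reference
        torch.cuda.empty_cache()  # hand the freed block back to the driver
    buf = torch.empty((size,), dtype=torch.int8, device=device)
    CAST_BUFFERS[key] = buf
    return buf
```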
@@ -1153,7 +1132,9 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    soft_empty_cache()
+    if comfy.memory_management.aimdo_allocator is None:
+        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
+        torch.cuda.empty_cache()

 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
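Note the asymmetry with the previous hunk: here the plain `soft_empty_cache()` becomes a guarded flush, because (per the new comment) PyTorch 2.7 and earlier can crash in `empty_cache()` while a mempool exists. A hedged sketch of the guard as a reusable helper; `active_mempool` is a stand-in for ComfyUI's `aimdo_allocator`:

```python
import torch

active_mempool = None  # stand-in for comfy.memory_management.aimdo_allocator

def safe_empty_cache():
    # PyTorch 2.7 and earlier can crash if empty_cache() runs while a
    # private CUDA mempool exists, so flush only when no pool is active.
    if active_mempool is None and torch.cuda.is_available():
        torch.cuda.empty_cache()
```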
@@ -1303,7 +1284,7 @@ def discard_cuda_async_error():
         a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
         _ = a + b
-        synchronize()
+        torch.cuda.synchronize()
     except torch.AcceleratorError:
         #Dump it! We already know about it from the synchronous return
         pass
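For context, `discard_cuda_async_error()` exists because CUDA reports many failures asynchronously: launching a trivial kernel and synchronizing forces any pending error to surface at a known point, where it can be swallowed instead of poisoning a later, unrelated call. A self-contained sketch (requires a PyTorch new enough to expose `torch.AcceleratorError`; device handling simplified relative to ComfyUI's `get_torch_device()`):

```python
import torch

def discard_pending_accelerator_error(device="cuda"):
    if not torch.cuda.is_available():
        return
    try:
        a = torch.tensor([1], dtype=torch.uint8, device=device)
        b = torch.tensor([1], dtype=torch.uint8, device=device)
        _ = a + b                 # trivial launch to flush the queue
        torch.cuda.synchronize()  # any stale async error raises here
    except torch.AcceleratorError:
        pass  # already reported once via the synchronous return; drop it
```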
@@ -1707,12 +1688,6 @@ def lora_compute_dtype(device):
     LORA_COMPUTE_DTYPES[device] = dtype
     return dtype

-def synchronize():
-    if is_intel_xpu():
-        torch.xpu.synchronize()
-    elif torch.cuda.is_available():
-        torch.cuda.synchronize()
-
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
@@ -1738,6 +1713,9 @@ def debug_memory_summary():
         return torch.cuda.memory.memory_summary()
     return ""

+#TODO: might be cleaner to put this somewhere else
+import threading
+
 class InterruptProcessingException(Exception):
     pass

@@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher):

         if unpatch_weights:
             self.partially_unload_ram(1e32)
-            self.partially_unload(None, 1e32)
+            self.partially_unload(None)

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
@@ -8,7 +8,7 @@ import torch
 class Qwen3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -118,7 +118,7 @@ class MistralTokenizerClass:
 class Mistral3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         self.tekken_data = tokenizer_data.get("tekken_model", None)
-        super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
+        super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"tekken_model": self.tekken_data}
@@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
 class Qwen3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)

 class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)

 class KleinTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
@@ -6,7 +6,7 @@ import os
 class Qwen3Tokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
-        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)


 class ZImageTokenizer(sd1_clip.SD1Tokenizer):
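The four tokenizer hunks above all make the same edit: the `embedding_directory=embedding_directory` forwarding is dropped from each `super().__init__(...)` call, so the base class falls back to its own default for that parameter. A minimal, purely illustrative example of the difference (hypothetical classes, not ComfyUI code):

```python
class BaseTokenizer:
    def __init__(self, embedding_directory=None):
        self.embedding_directory = embedding_directory

class Forwarding(BaseTokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__(embedding_directory=embedding_directory)

class NonForwarding(BaseTokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__()  # argument accepted but not passed through

print(Forwarding("embeddings/").embedding_directory)     # embeddings/
print(NonForwarding("embeddings/").embedding_directory)  # None
```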
Deleted file (51 lines) — the HitPaw request/response models; the node file in the next hunk imports them from comfy_api_nodes.apis.hitpaw:

@@ -1,51 +0,0 @@
-from typing import TypedDict
-
-from pydantic import BaseModel, Field
-
-
-class InputVideoModel(TypedDict):
-    model: str
-    resolution: str
-
-
-class ImageEnhanceTaskCreateRequest(BaseModel):
-    model_name: str = Field(...)
-    img_url: str = Field(...)
-    extension: str = Field(".png")
-    exif: bool = Field(False)
-    DPI: int | None = Field(None)
-
-
-class VideoEnhanceTaskCreateRequest(BaseModel):
-    video_url: str = Field(...)
-    extension: str = Field(".mp4")
-    model_name: str | None = Field(...)
-    resolution: list[int] = Field(..., description="Target resolution [width, height]")
-    original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")
-
-
-class TaskCreateDataResponse(BaseModel):
-    job_id: str = Field(...)
-    consume_coins: int | None = Field(None)
-
-
-class TaskStatusPollRequest(BaseModel):
-    job_id: str = Field(...)
-
-
-class TaskCreateResponse(BaseModel):
-    code: int = Field(...)
-    message: str = Field(...)
-    data: TaskCreateDataResponse | None = Field(None)
-
-
-class TaskStatusDataResponse(BaseModel):
-    job_id: str = Field(...)
-    status: str = Field(...)
-    res_url: str = Field("")
-
-
-class TaskStatusResponse(BaseModel):
-    code: int = Field(...)
-    message: str = Field(...)
-    data: TaskStatusDataResponse = Field(...)
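The removed module paired each create-request with a `code`/`message`/`data` response wrapper, which the node code in the next hunk checks and polls against. A short usage sketch of one request model, redeclared locally so it runs standalone (assumes pydantic v2 for `model_dump_json`):

```python
from pydantic import BaseModel, Field

class VideoEnhanceTaskCreateRequest(BaseModel):
    video_url: str = Field(...)
    extension: str = Field(".mp4")
    model_name: str | None = Field(...)
    resolution: list[int] = Field(..., description="Target resolution [width, height]")
    original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")

req = VideoEnhanceTaskCreateRequest(
    video_url="https://example.com/input.mp4",  # placeholder URL
    model_name="general_restore_2x",
    resolution=[2560, 1440],
    original_resolution=[1280, 720],
)
print(req.model_dump_json())
```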
Deleted file (342 lines) — the HitPaw API node implementation:

@@ -1,342 +0,0 @@
-import math
-
-from typing_extensions import override
-
-from comfy_api.latest import IO, ComfyExtension, Input
-from comfy_api_nodes.apis.hitpaw import (
-    ImageEnhanceTaskCreateRequest,
-    InputVideoModel,
-    TaskCreateDataResponse,
-    TaskCreateResponse,
-    TaskStatusPollRequest,
-    TaskStatusResponse,
-    VideoEnhanceTaskCreateRequest,
-)
-from comfy_api_nodes.util import (
-    ApiEndpoint,
-    download_url_to_image_tensor,
-    download_url_to_video_output,
-    downscale_image_tensor,
-    get_image_dimensions,
-    poll_op,
-    sync_op,
-    upload_image_to_comfyapi,
-    upload_video_to_comfyapi,
-    validate_video_duration,
-)
-
-VIDEO_MODELS_MODELS_MAP = {
-    "Portrait Restore Model (1x)": "portrait_restore_1x",
-    "Portrait Restore Model (2x)": "portrait_restore_2x",
-    "General Restore Model (1x)": "general_restore_1x",
-    "General Restore Model (2x)": "general_restore_2x",
-    "General Restore Model (4x)": "general_restore_4x",
-    "Ultra HD Model (2x)": "ultrahd_restore_2x",
-    "Generative Model (1x)": "generative_1x",
-}
-
-# Resolution name to target dimension (shorter side) in pixels
-RESOLUTION_TARGET_MAP = {
-    "720p": 720,
-    "1080p": 1080,
-    "2K/QHD": 1440,
-    "4K/UHD": 2160,
-    "8K": 4320,
-}
-
-# Square (1:1) resolutions use standard square dimensions
-RESOLUTION_SQUARE_MAP = {
-    "720p": 720,
-    "1080p": 1080,
-    "2K/QHD": 1440,
-    "4K/UHD": 2048,  # DCI 4K square
-    "8K": 4096,  # DCI 8K square
-}
-
-# Models with limited resolution support (no 8K)
-LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"}
-
-# Resolution options for different model types
-RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"]
-RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"]
-
-# Maximum output resolution in pixels
-MAX_PIXELS_GENERATIVE = 32_000_000
-MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000
-
-
-class HitPawGeneralImageEnhance(IO.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        return IO.Schema(
-            node_id="HitPawGeneralImageEnhance",
-            display_name="HitPaw General Image Enhance",
-            category="api node/image/HitPaw",
-            description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
-                        f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
-            inputs=[
-                IO.Combo.Input("model", options=["generative_portrait", "generative"]),
-                IO.Image.Input("image"),
-                IO.Combo.Input("upscale_factor", options=[1, 2, 4]),
-                IO.Boolean.Input(
-                    "auto_downscale",
-                    default=False,
-                    tooltip="Automatically downscale input image if output would exceed the limit.",
-                ),
-            ],
-            outputs=[
-                IO.Image.Output(),
-            ],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-            price_badge=IO.PriceBadge(
-                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
-                expr="""
-                (
-                    $prices := {
-                        "generative_portrait": {"min": 0.02, "max": 0.06},
-                        "generative": {"min": 0.05, "max": 0.15}
-                    };
-                    $price := $lookup($prices, widgets.model);
-                    {
-                        "type": "range_usd",
-                        "min_usd": $price.min,
-                        "max_usd": $price.max
-                    }
-                )
-                """,
-            ),
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        model: str,
-        image: Input.Image,
-        upscale_factor: int,
-        auto_downscale: bool,
-    ) -> IO.NodeOutput:
-        height, width = get_image_dimensions(image)
-        requested_scale = upscale_factor
-        output_pixels = height * width * requested_scale * requested_scale
-        if output_pixels > MAX_PIXELS_GENERATIVE:
-            if auto_downscale:
-                input_pixels = width * height
-                scale = 1
-                max_input_pixels = MAX_PIXELS_GENERATIVE
-
-                for candidate in [4, 2, 1]:
-                    if candidate > requested_scale:
-                        continue
-                    scale_output_pixels = input_pixels * candidate * candidate
-                    if scale_output_pixels <= MAX_PIXELS_GENERATIVE:
-                        scale = candidate
-                        max_input_pixels = None
-                        break
-                    # Check if we can downscale input by at most 2x to fit
-                    downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE)
-                    if downscale_ratio <= 2.0:
-                        scale = candidate
-                        max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate)
-                        break
-
-                if max_input_pixels is not None:
-                    image = downscale_image_tensor(image, total_pixels=max_input_pixels)
-                upscale_factor = scale
-            else:
-                output_width = width * requested_scale
-                output_height = height * requested_scale
-                raise ValueError(
-                    f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) "
-                    f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). "
-                    f"Enable auto_downscale or use a smaller input image or a lower upscale factor."
-                )
-
-        initial_res = await sync_op(
-            cls,
-            ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"),
-            response_model=TaskCreateResponse,
-            data=ImageEnhanceTaskCreateRequest(
-                model_name=f"{model}_{upscale_factor}x",
-                img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None),
-            ),
-            wait_label="Creating task",
-            final_label_on_success="Task created",
-        )
-        if initial_res.code != 200:
-            raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
-        request_price = initial_res.data.consume_coins / 1000
-        final_response = await poll_op(
-            cls,
-            ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
-            data=TaskCreateDataResponse(job_id=initial_res.data.job_id),
-            response_model=TaskStatusResponse,
-            status_extractor=lambda x: x.data.status,
-            price_extractor=lambda x: request_price,
-            poll_interval=10.0,
-            max_poll_attempts=480,
-        )
-        return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
-
-
-class HitPawVideoEnhance(IO.ComfyNode):
-    @classmethod
-    def define_schema(cls):
-        model_options = []
-        for model_name in VIDEO_MODELS_MODELS_MAP:
-            if model_name in LIMITED_RESOLUTION_MODELS:
-                resolutions = RESOLUTIONS_LIMITED
-            else:
-                resolutions = RESOLUTIONS_FULL
-            model_options.append(
-                IO.DynamicCombo.Option(
-                    model_name,
-                    [IO.Combo.Input("resolution", options=resolutions)],
-                )
-            )
-
-        return IO.Schema(
-            node_id="HitPawVideoEnhance",
-            display_name="HitPaw Video Enhance",
-            category="api node/video/HitPaw",
-            description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
-                        "Prices shown are per second of video.",
-            inputs=[
-                IO.DynamicCombo.Input("model", options=model_options),
-                IO.Video.Input("video"),
-            ],
-            outputs=[
-                IO.Video.Output(),
-            ],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-            price_badge=IO.PriceBadge(
-                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
-                expr="""
-                (
-                    $m := $lookup(widgets, "model");
-                    $res := $lookup(widgets, "model.resolution");
-                    $standard_model_prices := {
-                        "original": {"min": 0.01, "max": 0.198},
-                        "720p": {"min": 0.01, "max": 0.06},
-                        "1080p": {"min": 0.015, "max": 0.09},
-                        "2k/qhd": {"min": 0.02, "max": 0.117},
-                        "4k/uhd": {"min": 0.025, "max": 0.152},
-                        "8k": {"min": 0.033, "max": 0.198}
-                    };
-                    $ultra_hd_model_prices := {
-                        "original": {"min": 0.015, "max": 0.264},
-                        "720p": {"min": 0.015, "max": 0.092},
-                        "1080p": {"min": 0.02, "max": 0.12},
-                        "2k/qhd": {"min": 0.026, "max": 0.156},
-                        "4k/uhd": {"min": 0.034, "max": 0.203},
-                        "8k": {"min": 0.044, "max": 0.264}
-                    };
-                    $generative_model_prices := {
-                        "original": {"min": 0.015, "max": 0.338},
-                        "720p": {"min": 0.008, "max": 0.090},
-                        "1080p": {"min": 0.05, "max": 0.15},
-                        "2k/qhd": {"min": 0.038, "max": 0.225},
-                        "4k/uhd": {"min": 0.056, "max": 0.338}
-                    };
-                    $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices :
-                               $contains($m, "generative") ? $generative_model_prices :
-                               $standard_model_prices;
-                    $price := $lookup($prices, $res);
-                    {
-                        "type": "range_usd",
-                        "min_usd": $price.min,
-                        "max_usd": $price.max,
-                        "format": {"approximate": true, "suffix": "/second"}
-                    }
-                )
-                """,
-            ),
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        model: InputVideoModel,
-        video: Input.Video,
-    ) -> IO.NodeOutput:
-        validate_video_duration(video, min_duration=0.5, max_duration=60 * 60)
-        resolution = model["resolution"]
-        src_width, src_height = video.get_dimensions()
-
-        if resolution == "original":
-            output_width = src_width
-            output_height = src_height
-        else:
-            if src_width == src_height:
-                target_size = RESOLUTION_SQUARE_MAP[resolution]
-                if target_size < src_width:
-                    raise ValueError(
-                        f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than "
-                        f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'."
-                    )
-                output_width = target_size
-                output_height = target_size
-            else:
-                min_dimension = min(src_width, src_height)
-                target_size = RESOLUTION_TARGET_MAP[resolution]
-                if target_size < min_dimension:
-                    raise ValueError(
-                        f"Selected resolution {resolution} ({target_size}p) is smaller than "
-                        f"the input video's shorter dimension ({min_dimension}p). "
-                        f"Please select a higher resolution or 'original'."
-                    )
-                if src_width > src_height:
-                    output_height = target_size
-                    output_width = int(target_size * (src_width / src_height))
-                else:
-                    output_width = target_size
-                    output_height = int(target_size * (src_height / src_width))
-        initial_res = await sync_op(
-            cls,
-            ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"),
-            response_model=TaskCreateResponse,
-            data=VideoEnhanceTaskCreateRequest(
-                video_url=await upload_video_to_comfyapi(cls, video),
-                resolution=[output_width, output_height],
-                original_resolution=[src_width, src_height],
-                model_name=VIDEO_MODELS_MODELS_MAP[model["model"]],
-            ),
-            wait_label="Creating task",
-            final_label_on_success="Task created",
-        )
-        request_price = initial_res.data.consume_coins / 1000
-        if initial_res.code != 200:
-            raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
-        final_response = await poll_op(
-            cls,
-            ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
-            data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
-            response_model=TaskStatusResponse,
-            status_extractor=lambda x: x.data.status,
-            price_extractor=lambda x: request_price,
-            poll_interval=10.0,
-            max_poll_attempts=320,
-        )
-        return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))
-
-
-class HitPawExtension(ComfyExtension):
-    @override
-    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
-        return [
-            HitPawGeneralImageEnhance,
-            HitPawVideoEnhance,
-        ]
-
-
-async def comfy_entrypoint() -> HitPawExtension:
-    return HitPawExtension()
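One detail of the removed `HitPawVideoEnhance.execute` worth noting: output dimensions are derived from a shorter-side target while preserving aspect ratio, and targets below the input's shorter side are rejected. A standalone sketch of that math, lifted from the deleted code (function name hypothetical; the separate `RESOLUTION_SQUARE_MAP` branch for 1:1 inputs is omitted for brevity):

```python
RESOLUTION_TARGET_MAP = {"720p": 720, "1080p": 1080, "2K/QHD": 1440, "4K/UHD": 2160, "8K": 4320}

def output_dims(src_width: int, src_height: int, resolution: str) -> tuple[int, int]:
    if resolution == "original":
        return src_width, src_height
    target = RESOLUTION_TARGET_MAP[resolution]  # target for the shorter side
    if target < min(src_width, src_height):
        raise ValueError("selected resolution is below the input's shorter side")
    if src_width > src_height:
        return int(target * (src_width / src_height)), target
    return target, int(target * (src_height / src_width))

print(output_dims(1280, 720, "4K/UHD"))  # 16:9 input scaled to a 2160 px short side
```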
@@ -94,7 +94,7 @@ async def upload_image_to_comfyapi(
     *,
     mime_type: str | None = None,
     wait_label: str | None = "Uploading",
-    total_pixels: int | None = 2048 * 2048,
+    total_pixels: int = 2048 * 2048,
 ) -> str:
     """Uploads a single image to ComfyUI API and returns its download URL."""
     return (
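This signature change ties back to the node deletion above: the removed `HitPawGeneralImageEnhance` was a caller that passed `total_pixels=None` to skip the downscale cap, and with that call gone the parameter can be narrowed so a cap is always supplied. A tiny hypothetical illustration of what the narrowing forbids:

```python
def cap_pixels(total_pixels: int = 2048 * 2048) -> int:
    # With the old `int | None` annotation, callers could pass None to
    # disable the cap; after narrowing to `int`, cap_pixels(None) no
    # longer type-checks.
    return total_pixels

cap_pixels()             # default cap
cap_pixels(4096 * 4096)  # explicit cap
```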