mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00

Merge branch 'master' into dr-support-pip-cm

This commit is contained in: commit ad4b959d7e
@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default 4GB.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
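The new --cache-ram flag enables RAM-pressure caching: when available system memory falls below the configured headroom (default 4 GB), the cache frees its largest entries first. A minimal sketch of that eviction policy, assuming a psutil-based availability check and a plain dict cache (all names here are illustrative, not the actual ComfyUI cache implementation):

import psutil

GIB = 1024 ** 3

def evict_until_headroom(cache: dict, headroom_gb: float = 4.0) -> None:
    # Drop the largest cached results until available RAM exceeds the headroom.
    by_size = sorted(cache, key=lambda k: getattr(cache[k], "nbytes", 0), reverse=True)
    for key in by_size:
        if psutil.virtual_memory().available >= headroom_gb * GIB:
            break
        cache.pop(key)  # free the biggest item first, then re-check pressure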
@@ -276,6 +276,9 @@ class ModelPatcher:
         self.size = comfy.model_management.module_size(self.model)
         return self.size

+    def get_ram_usage(self):
+        return self.model_size()
+
     def loaded_size(self):
         return self.model.model_loaded_weight_memory

@@ -655,6 +658,7 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
         loading = self._load_list()

         load_completely = []
@@ -675,6 +679,7 @@ class ModelPatcher:
             if mem_counter + module_mem >= lowvram_model_memory:
                 lowvram_weight = True
                 lowvram_counter += 1
+                lowvram_mem_counter += module_mem
                 if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                     continue

@@ -748,10 +753,10 @@ class ModelPatcher:
                 self.pin_weight_to_device("{}.{}".format(n, param))

         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
             if full_load:
                 self.model.to(device_to)
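With lowvram_mem_counter tracking the weights left on the offload device, the partial-load message now reports usable, loaded and offloaded memory in MB plus the patch count. For example, with hypothetical values:

import logging
logging.basicConfig(level=logging.INFO)

lowvram_model_memory = 6 * 1024**3   # hypothetical memory budget for the partial load
mem_counter = 5.2 * 1024**3          # hypothetical bytes of weights moved to the GPU
lowvram_mem_counter = 2.9 * 1024**3  # hypothetical bytes of weights left offloaded
patch_counter = 12

logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(
    lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
# -> loaded partially; 6144.00 MB usable, 5324.80 MB loaded, 2969.60 MB offloaded, lowvram patches: 12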
@@ -421,14 +421,18 @@ def fp8_linear(self, input):

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
+           input = torch.clamp(input, min=-448, max=448, out=input)
+           input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
+           layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+           quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
        else:
            scale_input = scale_input.to(input.device)
+           quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)

        # Wrap weight in QuantizedTensor - this enables unified dispatch
        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
        quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
-       quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)

        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

        uncast_bias_weight(self, w, bias, offload_stream)
@@ -357,9 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)

-       lp_amax = torch.finfo(dtype).max
        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
-       torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+       # TODO: uncomment this if it's actually needed because the clamp has a small performance penalty
+       # lp_amax = torch.finfo(dtype).max
+       # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
        qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

        layout_params = {
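TensorCoreFP8Layout.from_float now skips the explicit clamp before the FP8 cast and keeps the old lines around as comments. For reference, a standalone sketch of the scale-then-cast step with the optional clamp restored (float8_e4m3fn saturates at ±448; this is an illustration, not the ComfyUI code):

import torch

def fp8_from_float(tensor: torch.Tensor, scale: torch.Tensor,
                   dtype=torch.float8_e4m3fn, clamp: bool = False) -> torch.Tensor:
    # Scale the tensor into the FP8 range, optionally clamp, then cast to the storage dtype.
    scale = scale.to(device=tensor.device, dtype=torch.float32)
    tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
    if clamp:
        lp_amax = torch.finfo(dtype).max  # 448.0 for float8_e4m3fn
        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
    return tensor_scaled.to(dtype, memory_format=torch.contiguous_format)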
comfy/sd.py (14 changed lines)
@@ -143,6 +143,9 @@ class CLIP:
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n

+    def get_ram_usage(self):
+        return self.patcher.get_ram_usage()
+
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)

@@ -293,6 +296,7 @@ class VAE:
         self.working_dtypes = [torch.bfloat16, torch.float32]
         self.disable_offload = False
         self.not_video = False
+        self.size = None

         self.downscale_index_formula = None
         self.upscale_index_formula = None
@@ -595,6 +599,16 @@ class VAE:

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
         logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        self.model_size()
+
+    def model_size(self):
+        if self.size is not None:
+            return self.size
+        self.size = comfy.model_management.module_size(self.first_stage_model)
+        return self.size
+
+    def get_ram_usage(self):
+        return self.model_size()

     def throw_exception_if_invalid(self):
         if self.first_stage_model is None:
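The VAE gains the same lazy size bookkeeping as ModelPatcher: the parameter size is computed once, cached in self.size, and get_ram_usage reuses it. The pattern in isolation, using a parameter-size sum as a stand-in for comfy.model_management.module_size (sketch, not the real helper):

import torch

class SizedModule:
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.size = None

    def model_size(self) -> int:
        if self.size is not None:
            return self.size  # cached after the first call
        self.size = sum(p.numel() * p.element_size() for p in self.model.parameters())
        return self.size

    def get_ram_usage(self) -> int:
        return self.model_size()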
@@ -1,15 +1,12 @@
 from __future__ import annotations
 import aiohttp
 import mimetypes
-from typing import Optional, Union
-from comfy.utils import common_upscale
+from typing import Union
 from server import PromptServer
-from comfy.cli_args import args

 import numpy as np
 from PIL import Image
 import torch
-import math
 import base64
 from io import BytesIO

@@ -60,85 +57,6 @@ async def validate_and_cast_response(
     return torch.stack(image_tensors, dim=0)


-def validate_aspect_ratio(
-    aspect_ratio: str,
-    minimum_ratio: float,
-    maximum_ratio: float,
-    minimum_ratio_str: str,
-    maximum_ratio_str: str,
-) -> float:
-    """Validates and casts an aspect ratio string to a float.
-
-    Args:
-        aspect_ratio: The aspect ratio string to validate.
-        minimum_ratio: The minimum aspect ratio.
-        maximum_ratio: The maximum aspect ratio.
-        minimum_ratio_str: The minimum aspect ratio string.
-        maximum_ratio_str: The maximum aspect ratio string.
-
-    Returns:
-        The validated and cast aspect ratio.
-
-    Raises:
-        Exception: If the aspect ratio is not valid.
-    """
-    # get ratio values
-    numbers = aspect_ratio.split(":")
-    if len(numbers) != 2:
-        raise TypeError(
-            f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
-        )
-    try:
-        numerator = int(numbers[0])
-        denominator = int(numbers[1])
-    except ValueError as exc:
-        raise TypeError(
-            f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
-        ) from exc
-    calculated_ratio = numerator / denominator
-    # if not close to minimum and maximum, check bounds
-    if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
-        calculated_ratio, maximum_ratio
-    ):
-        if calculated_ratio < minimum_ratio:
-            raise TypeError(
-                f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
-            )
-        if calculated_ratio > maximum_ratio:
-            raise TypeError(
-                f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
-            )
-    return aspect_ratio
-
-
-async def download_url_to_bytesio(
-    url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
-) -> BytesIO:
-    """Downloads content from a URL using requests and returns it as BytesIO.
-
-    Args:
-        url: The URL to download.
-        timeout: Request timeout in seconds. Defaults to None (no timeout).
-
-    Returns:
-        BytesIO object containing the downloaded content.
-    """
-    headers = {}
-    if url.startswith("/proxy/"):
-        url = str(args.comfy_api_base).rstrip("/") + url
-        auth_token = auth_kwargs.get("auth_token")
-        comfy_api_key = auth_kwargs.get("comfy_api_key")
-        if auth_token:
-            headers["Authorization"] = f"Bearer {auth_token}"
-        elif comfy_api_key:
-            headers["X-API-KEY"] = comfy_api_key
-    timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
-    async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
-        async with session.get(url, headers=headers) as resp:
-            resp.raise_for_status()  # Raises HTTPError for bad responses (4XX or 5XX)
-            return BytesIO(await resp.read())
-
-
 def text_filepath_to_base64_string(filepath: str) -> str:
     """Converts a text file to a base64 string."""
     with open(filepath, "rb") as f:
@@ -153,28 +71,3 @@ def text_filepath_to_data_uri(filepath: str) -> str:
     if mime_type is None:
         mime_type = "application/octet-stream"
     return f"data:{mime_type};base64,{base64_string}"
-
-
-def resize_mask_to_image(
-    mask: torch.Tensor,
-    image: torch.Tensor,
-    upscale_method="nearest-exact",
-    crop="disabled",
-    allow_gradient=True,
-    add_channel_dim=False,
-):
-    """
-    Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
-    """
-    _, H, W, _ = image.shape
-    mask = mask.unsqueeze(-1)
-    mask = mask.movedim(-1, 1)
-    mask = common_upscale(
-        mask, width=W, height=H, upscale_method=upscale_method, crop=crop
-    )
-    mask = mask.movedim(1, -1)
-    if not add_channel_dim:
-        mask = mask.squeeze(-1)
-    if not allow_gradient:
-        mask = (mask > 0.5).float()
-    return mask
@@ -5,10 +5,6 @@ import torch
 from typing_extensions import override

 from comfy_api.latest import IO, ComfyExtension
-from comfy_api_nodes.apinode_utils import (
-    resize_mask_to_image,
-    validate_aspect_ratio,
-)
 from comfy_api_nodes.apis.bfl_api import (
     BFLFluxExpandImageRequest,
     BFLFluxFillImageRequest,
@@ -23,8 +19,10 @@ from comfy_api_nodes.util import (
     ApiEndpoint,
     download_url_to_image_tensor,
     poll_op,
+    resize_mask_to_image,
     sync_op,
     tensor_to_base64_string,
+    validate_aspect_ratio_string,
     validate_string,
 )

@@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
     Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
     """

-    MINIMUM_RATIO = 1 / 4
-    MAXIMUM_RATIO = 4 / 1
-    MINIMUM_RATIO_STR = "1:4"
-    MAXIMUM_RATIO_STR = "4:1"
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
@@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):

     @classmethod
     def validate_inputs(cls, aspect_ratio: str):
-        try:
-            validate_aspect_ratio(
-                aspect_ratio,
-                minimum_ratio=cls.MINIMUM_RATIO,
-                maximum_ratio=cls.MAXIMUM_RATIO,
-                minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-                maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-            )
-        except Exception as e:
-            return str(e)
+        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
         return True

     @classmethod
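validate_aspect_ratio_string replaces the per-node MINIMUM_RATIO/MAXIMUM_RATIO constants and the removed validate_aspect_ratio helper; the bounds are now passed as (width, height) tuples, so (1, 4) and (4, 1) cover everything from 1:4 to 4:1. A rough sketch of what such a validator could look like (the real helper lives in comfy_api_nodes.util and may differ, for example in the exception type it raises):

def validate_aspect_ratio_string(aspect_ratio: str, min_ratio: tuple, max_ratio: tuple) -> None:
    # Parse an "X:Y" string and check it against the (w, h) ratio bounds.
    try:
        w, h = (int(x) for x in aspect_ratio.split(":"))
    except ValueError as exc:
        raise ValueError(f"Aspect ratio must look like '16:9', got {aspect_ratio!r}.") from exc
    ratio = w / h
    lo, hi = min_ratio[0] / min_ratio[1], max_ratio[0] / max_ratio[1]
    if not lo <= ratio <= hi:
        raise ValueError(
            f"Aspect ratio {aspect_ratio} must be between "
            f"{min_ratio[0]}:{min_ratio[1]} and {max_ratio[0]}:{max_ratio[1]}."
        )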
@@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
             prompt=prompt,
             prompt_upsampling=prompt_upsampling,
             seed=seed,
-            aspect_ratio=validate_aspect_ratio(
-                aspect_ratio,
-                minimum_ratio=cls.MINIMUM_RATIO,
-                maximum_ratio=cls.MAXIMUM_RATIO,
-                minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-                maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-            ),
+            aspect_ratio=aspect_ratio,
             raw=raw,
             image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
             image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
@@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
     Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
     """

-    MINIMUM_RATIO = 1 / 4
-    MAXIMUM_RATIO = 4 / 1
-    MINIMUM_RATIO_STR = "1:4"
-    MAXIMUM_RATIO_STR = "4:1"
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
@@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
         seed=0,
         prompt_upsampling=False,
     ) -> IO.NodeOutput:
-        aspect_ratio = validate_aspect_ratio(
-            aspect_ratio,
-            minimum_ratio=cls.MINIMUM_RATIO,
-            maximum_ratio=cls.MAXIMUM_RATIO,
-            minimum_ratio_str=cls.MINIMUM_RATIO_STR,
-            maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
-        )
+        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
         if input_image is None:
             validate_string(prompt, strip_whitespace=False)
         initial_response = await sync_op(
@@ -17,7 +17,7 @@ from comfy_api_nodes.util import (
     poll_op,
     sync_op,
     upload_images_to_comfyapi,
-    validate_image_aspect_ratio_range,
+    validate_image_aspect_ratio,
     validate_image_dimensions,
     validate_string,
 )
@@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
         validate_string(prompt, strip_whitespace=True, min_length=1)
         if get_number_of_images(image) != 1:
             raise ValueError("Exactly one input image is required.")
-        validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
+        validate_image_aspect_ratio(image, (1, 3), (3, 1))
         source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
         payload = Image2ImageTaskCreationRequest(
             model=model,
@@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
         reference_images_urls = []
         if n_input_images:
             for i in image:
-                validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
+                validate_image_aspect_ratio(i, (1, 3), (3, 1))
             reference_images_urls = await upload_images_to_comfyapi(
                 cls,
                 image,
@@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
         validate_string(prompt, strip_whitespace=True, min_length=1)
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
         validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-        validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
+        validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5

         image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
         prompt = (
@@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
         for i in (first_frame, last_frame):
             validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
+            validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5

         download_urls = await upload_images_to_comfyapi(
             cls,
@@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
         raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
         for image in images:
             validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
+            validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5

         image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
         prompt = (
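The ByteDance nodes now call validate_image_aspect_ratio with (width, height) tuples instead of the old *_range variant, e.g. (1, 3)-(3, 1) for stills and (2, 5)-(5, 2) for video frames. A sketch of such a check for a ComfyUI image tensor of shape [B, H, W, C]; treating strict as "exclusive versus inclusive bounds" is an assumption about the real helper:

import torch

def validate_image_aspect_ratio(image: torch.Tensor, min_ratio: tuple, max_ratio: tuple, strict: bool = True) -> None:
    # Compare width/height of a [B, H, W, C] image tensor against (w, h) ratio bounds.
    _, height, width, _ = image.shape
    ratio = width / height
    lo, hi = min_ratio[0] / min_ratio[1], max_ratio[0] / max_ratio[1]
    ok = (lo < ratio < hi) if strict else (lo <= ratio <= hi)
    if not ok:
        raise ValueError(f"Image aspect ratio {ratio:.3f} must be between {lo:.3f} and {hi:.3f}.")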
@@ -1,6 +1,6 @@
 from io import BytesIO
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, IO
+from comfy_api.latest import IO, ComfyExtension
 from PIL import Image
 import numpy as np
 import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
     IdeogramV3Request,
     IdeogramV3EditRequest,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-)
-
-from comfy_api_nodes.apinode_utils import (
-    download_url_to_bytesio,
+    bytesio_to_image_tensor,
+    download_url_as_bytesio,
     resize_mask_to_image,
+    sync_op,
 )
-from comfy_api_nodes.util import bytesio_to_image_tensor
-from server import PromptServer

 V1_V1_RES_MAP = {
     "Auto":"AUTO",
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):

     for image_url in image_urls:
         # Using functions from apinode_utils.py to handle downloading and processing
-        image_bytesio = await download_url_to_bytesio(image_url)  # Download image content to BytesIO
+        image_bytesio = await download_url_as_bytesio(image_url)  # Download image content to BytesIO
         img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB")  # Convert to torch.Tensor with RGB mode
         image_tensors.append(img_tensor)

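download_url_as_bytesio plus bytesio_to_image_tensor replace the removed download_url_to_bytesio helper; the second step decodes the bytes into the [1, H, W, C] float tensor ComfyUI expects. A self-contained sketch of that decode step (illustrative; the real util may handle more modes and edge cases):

from io import BytesIO

import numpy as np
import torch
from PIL import Image

def bytes_to_image_tensor(data: bytes, mode: str = "RGB") -> torch.Tensor:
    # Decode raw image bytes into a [1, H, W, C] float32 tensor in the 0..1 range.
    img = Image.open(BytesIO(data)).convert(mode)
    arr = np.asarray(img, dtype=np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)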
@@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
     return stacked_tensors


-def display_image_urls_on_node(image_urls, node_id):
-    if node_id and image_urls:
-        if len(image_urls) == 1:
-            PromptServer.instance.send_progress_text(
-                f"Generated Image URL:\n{image_urls[0]}", node_id
-            )
-        else:
-            urls_text = "Generated Image URLs:\n" + "\n".join(
-                f"{i+1}. {url}" for i, url in enumerate(image_urls)
-            )
-            PromptServer.instance.send_progress_text(urls_text, node_id)
-
-
 class IdeogramV1(IO.ComfyNode):

     @classmethod
@@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
         aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
         model = "V_1_TURBO" if turbo else "V_1"

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
             response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+            data=IdeogramGenerateRequest(
                 image_request=ImageRequest(
                     prompt=prompt,
                     model=model,
                     num_images=num_images,
                     seed=seed,
                     aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                     negative_prompt=negative_prompt if negative_prompt else None,
                 )
             ),
-            auth_kwargs=auth,
+            max_retries=1,
         )

-        response = await operation.execute()
-
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")

         image_urls = [image_data.url for image_data in response.data if image_data.url]

         if not image_urls:
             raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))


@@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
         else:
             final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
+        response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
             response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+            data=IdeogramGenerateRequest(
                 image_request=ImageRequest(
                     prompt=prompt,
                     model=model,
@@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
                     seed=seed,
                     aspect_ratio=final_aspect_ratio,
                     resolution=final_resolution,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                     style_type=style_type if style_type != "NONE" else None,
                     negative_prompt=negative_prompt if negative_prompt else None,
                     color_palette=color_palette if color_palette else None,
                 )
             ),
-            auth_kwargs=auth,
+            max_retries=1,
         )

-        response = await operation.execute()
-
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")

         image_urls = [image_data.url for image_data in response.data if image_data.url]

         if not image_urls:
             raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))


@@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
         character_image=None,
         character_mask=None,
     ):
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         if rendering_speed == "BALANCED": # for backward compatibility
             rendering_speed = "DEFAULT"

@@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):

         # Check if both image and mask are provided for editing mode
         if image is not None and mask is not None:
-            # Edit mode
-            path = "/proxy/ideogram/ideogram-v3/edit"
-
             # Process image and mask
             input_tensor = image.squeeze().cpu()
             # Resize mask to match image dimension
@@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
             if character_mask_binary:
                 files["character_mask_binary"] = character_mask_binary

-            # Execute the operation for edit mode
-            operation = SynchronousOperation(
-                endpoint=ApiEndpoint(
-                    path=path,
-                    method=HttpMethod.POST,
-                    request_model=IdeogramV3EditRequest,
+            response = await sync_op(
+                cls,
+                ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
                 response_model=IdeogramGenerateResponse,
-                ),
-                request=edit_request,
+                data=edit_request,
                 files=files,
                 content_type="multipart/form-data",
-                auth_kwargs=auth,
+                max_retries=1,
             )

         elif image is not None or mask is not None:
             # If only one of image or mask is provided, raise an error
             raise Exception("Ideogram V3 image editing requires both an image AND a mask")
         else:
-            # Generation mode
-            path = "/proxy/ideogram/ideogram-v3/generate"
-
             # Create generation request
             gen_request = IdeogramV3Request(
                 prompt=prompt,
@@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
             if files:
                 gen_request.style_type = "AUTO"

-            # Execute the operation for generation mode
-            operation = SynchronousOperation(
-                endpoint=ApiEndpoint(
-                    path=path,
-                    method=HttpMethod.POST,
-                    request_model=IdeogramV3Request,
+            response = await sync_op(
+                cls,
+                endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
                 response_model=IdeogramGenerateResponse,
-                ),
-                request=gen_request,
+                data=gen_request,
                 files=files if files else None,
                 content_type="multipart/form-data",
-                auth_kwargs=auth,
+                max_retries=1,
             )

-        # Execute the operation and process response
-        response = await operation.execute()
-
         if not response.data or len(response.data) == 0:
             raise Exception("No images were generated in the response")

         image_urls = [image_data.url for image_data in response.data if image_data.url]

         if not image_urls:
             raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
         return IO.NodeOutput(await download_and_process_images(image_urls))


@@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
             IdeogramV3,
         ]

+
 async def comfy_entrypoint() -> IdeogramExtension:
     return IdeogramExtension()
@@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
     See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
     """
     validate_image_dimensions(image, min_width=300, min_height=300)
-    validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
+    validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))


 def get_video_from_response(response) -> KlingVideoResult:
@@ -225,7 +225,7 @@ class OpenAIDalle2(ComfyNodeABC):
             ),
             files=(
                 {
-                    "image": img_binary,
+                    "image": ("image.png", img_binary, "image/png"),
                 }
                 if img_binary
                 else None
@@ -1,7 +1,6 @@
-from inspect import cleandoc
-from typing import Optional
+import torch
 from typing_extensions import override
-from io import BytesIO
+from comfy_api.latest import IO, ComfyExtension
 from comfy_api_nodes.apis.pixverse_api import (
     PixverseTextVideoRequest,
     PixverseImageVideoRequest,
@@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
     PixverseIO,
     pixverse_templates,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
+    tensor_to_bytesio,
+    validate_string,
 )
-from comfy_api_nodes.util import validate_string, tensor_to_bytesio
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO
-
-import torch
-import aiohttp
-

 AVERAGE_DURATION_T2V = 32
 AVERAGE_DURATION_I2V = 30
 AVERAGE_DURATION_T2T = 52


-def get_video_url_from_response(
-    response: PixverseGenerationStatusResponse,
-) -> Optional[str]:
-    if response.Resp is None or response.Resp.url is None:
-        return None
-    return str(response.Resp.url)
-
-
-async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
-    # first, upload image to Pixverse and get image id to use in actual generation call
-    operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/proxy/pixverse/image/upload",
-            method=HttpMethod.POST,
-            request_model=EmptyRequest,
+async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
+    response_upload = await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
         response_model=PixverseImageUploadResponse,
-        ),
-        request=EmptyRequest(),
         files={"image": tensor_to_bytesio(image)},
         content_type="multipart/form-data",
-        auth_kwargs=auth_kwargs,
     )
-    response_upload: PixverseImageUploadResponse = await operation.execute()
-
     if response_upload.Resp is None:
         raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")

     return response_upload.Resp.img_id

@@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode):


 class PixverseTextToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseTextToVideoNode",
             display_name="PixVerse Text to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         negative_prompt: str = None,
         pixverse_template: int = None,
     ) -> IO.NodeOutput:
-        validate_string(prompt, strip_whitespace=False)
+        validate_string(prompt, strip_whitespace=False, min_length=1)
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
         if quality == PixverseQuality.res_1080p:
@@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/text/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseTextVideoRequest,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
             response_model=PixverseVideoResponse,
-            ),
-            request=PixverseTextVideoRequest(
+            data=PixverseTextVideoRequest(
                 prompt=prompt,
                 aspect_ratio=aspect_ratio,
                 quality=quality,
@@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 template_id=pixverse_template,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
-
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")

-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
             response_model=PixverseGenerationStatusResponse,
-            ),
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_T2V,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))


 class PixverseImageToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseImageToVideoNode",
             display_name="PixVerse Image to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.String.Input(
@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
|||||||
pixverse_template: int = None,
|
pixverse_template: int = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
auth = {
|
img_id = await upload_image_to_pixverse(cls, image)
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
|
|
||||||
|
|
||||||
# 1080p is limited to 5 seconds duration
|
# 1080p is limited to 5 seconds duration
|
||||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
|||||||
elif duration_seconds != PixverseDuration.dur_5:
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
motion_mode = PixverseMotionMode.normal
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/pixverse/video/img/generate",
|
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PixverseImageVideoRequest,
|
|
||||||
response_model=PixverseVideoResponse,
|
response_model=PixverseVideoResponse,
|
||||||
),
|
data=PixverseImageVideoRequest(
|
||||||
request=PixverseImageVideoRequest(
|
|
||||||
img_id=img_id,
|
img_id=img_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
|||||||
template_id=pixverse_template,
|
template_id=pixverse_template,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
operation = PollingOperation(
|
response_poll = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=PixverseGenerationStatusResponse,
|
response_model=PixverseGenerationStatusResponse,
|
||||||
),
|
|
||||||
completed_statuses=[PixverseStatus.successful],
|
completed_statuses=[PixverseStatus.successful],
|
||||||
failed_statuses=[
|
failed_statuses=[
|
||||||
PixverseStatus.contents_moderation,
|
PixverseStatus.contents_moderation,
|
||||||
@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
|||||||
PixverseStatus.deleted,
|
PixverseStatus.deleted,
|
||||||
],
|
],
|
||||||
status_extractor=lambda x: x.Resp.status,
|
status_extractor=lambda x: x.Resp.status,
|
||||||
auth_kwargs=auth,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
estimated_duration=AVERAGE_DURATION_I2V,
|
estimated_duration=AVERAGE_DURATION_I2V,
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.Resp.url) as vid_response:
|
|
||||||
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
|
||||||
|
|
||||||
|
|
||||||
class PixverseTransitionVideoNode(IO.ComfyNode):
|
class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos based on prompt and output_size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> IO.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="PixverseTransitionVideoNode",
|
node_id="PixverseTransitionVideoNode",
|
||||||
display_name="PixVerse Transition Video",
|
display_name="PixVerse Transition Video",
|
||||||
category="api node/video/PixVerse",
|
category="api node/video/PixVerse",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos based on prompt and output_size.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("first_frame"),
|
IO.Image.Input("first_frame"),
|
||||||
IO.Image.Input("last_frame"),
|
IO.Image.Input("last_frame"),
|
||||||
@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
|||||||
negative_prompt: str = None,
|
negative_prompt: str = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
auth = {
|
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
|
|
||||||
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
|
|
||||||
|
|
||||||
# 1080p is limited to 5 seconds duration
|
# 1080p is limited to 5 seconds duration
|
||||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
|||||||
elif duration_seconds != PixverseDuration.dur_5:
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
motion_mode = PixverseMotionMode.normal
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/pixverse/video/transition/generate",
|
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PixverseTransitionVideoRequest,
|
|
||||||
response_model=PixverseVideoResponse,
|
response_model=PixverseVideoResponse,
|
||||||
),
|
data=PixverseTransitionVideoRequest(
|
||||||
request=PixverseTransitionVideoRequest(
|
|
||||||
first_frame_img=first_frame_id,
|
first_frame_img=first_frame_id,
|
||||||
last_frame_img=last_frame_id,
|
last_frame_img=last_frame_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
|||||||
negative_prompt=negative_prompt if negative_prompt else None,
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
operation = PollingOperation(
|
response_poll = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=PixverseGenerationStatusResponse,
|
response_model=PixverseGenerationStatusResponse,
|
||||||
),
|
|
||||||
completed_statuses=[PixverseStatus.successful],
|
completed_statuses=[PixverseStatus.successful],
|
||||||
failed_statuses=[
|
failed_statuses=[
|
||||||
PixverseStatus.contents_moderation,
|
PixverseStatus.contents_moderation,
|
||||||
@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
|||||||
PixverseStatus.deleted,
|
PixverseStatus.deleted,
|
||||||
],
|
],
|
||||||
status_extractor=lambda x: x.Resp.status,
|
status_extractor=lambda x: x.Resp.status,
|
||||||
auth_kwargs=auth,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
estimated_duration=AVERAGE_DURATION_T2V,
|
estimated_duration=AVERAGE_DURATION_T2V,
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.Resp.url) as vid_response:
|
|
||||||
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
|
||||||
|
|
||||||
|
|
||||||
class PixVerseExtension(ComfyExtension):
|
class PixVerseExtension(ComfyExtension):
|
||||||
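For readability, here is the post-merge shape of the Pixverse transition call consolidated from the right-hand column above. This is a simplified sketch, not the exact node code: it only uses names visible in this diff (sync_op, poll_op, ApiEndpoint, the Pixverse request/response models) and omits the duration/motion_mode handling and the full failed-status list.

response_api = await sync_op(
    cls,
    ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
    response_model=PixverseVideoResponse,
    data=PixverseTransitionVideoRequest(
        first_frame_img=first_frame_id,
        last_frame_img=last_frame_id,
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else None,
        seed=seed,
    ),
)
if response_api.Resp is None:
    raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")

response_poll = await poll_op(
    cls,
    ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
    response_model=PixverseGenerationStatusResponse,
    completed_statuses=[PixverseStatus.successful],
    failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.deleted],  # abridged
    status_extractor=lambda x: x.Resp.status,
    estimated_duration=AVERAGE_DURATION_T2V,
)
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))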
|
|||||||
@ -8,9 +8,6 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
from comfy_api.latest import IO, ComfyExtension
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
resize_mask_to_image,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.recraft_api import (
|
from comfy_api_nodes.apis.recraft_api import (
|
||||||
RecraftColor,
|
RecraftColor,
|
||||||
RecraftColorChain,
|
RecraftColorChain,
|
||||||
@ -28,6 +25,7 @@ from comfy_api_nodes.util import (
|
|||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
download_url_as_bytesio,
|
download_url_as_bytesio,
|
||||||
|
resize_mask_to_image,
|
||||||
sync_op,
|
sync_op,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
validate_string,
|
validate_string,
|
||||||
|
|||||||
@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
|||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||||
|
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
cls,
|
cls,
|
||||||
@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
|||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||||
|
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
cls,
|
cls,
|
||||||
@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
|||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||||
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
|
||||||
|
|
||||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
|||||||
reference_images = None
|
reference_images = None
|
||||||
if reference_image is not None:
|
if reference_image is not None:
|
||||||
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
cls,
|
cls,
|
||||||
reference_image,
|
reference_image,
|
||||||
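The Runway call sites above move from float bounds to ratio tuples; a hedged equivalence sketch (the import path follows the comfy_api_nodes.util changes later in this merge):

from comfy_api_nodes.util import validate_image_aspect_ratio

# Old form: validate_image_aspect_ratio(frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
# New form: the same bounds expressed as ratio tuples, accepting ratios between 1:2 and 2:1.
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))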
|
|||||||
@ -14,9 +14,9 @@ from comfy_api_nodes.util import (
|
|||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
validate_aspect_ratio_closeness,
|
validate_image_aspect_ratio,
|
||||||
validate_image_aspect_ratio_range,
|
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
|
validate_images_aspect_ratio_closeness,
|
||||||
)
|
)
|
||||||
|
|
||||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||||
@ -114,7 +114,7 @@ async def execute_task(
|
|||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
||||||
response_model=TaskStatusResponse,
|
response_model=TaskStatusResponse,
|
||||||
status_extractor=lambda r: r.state.value,
|
status_extractor=lambda r: r.state,
|
||||||
estimated_duration=estimated_duration,
|
estimated_duration=estimated_duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
|||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if get_number_of_images(image) > 1:
|
if get_number_of_images(image) > 1:
|
||||||
raise ValueError("Only one input image is allowed.")
|
raise ValueError("Only one input image is allowed.")
|
||||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||||
payload = TaskCreationRequest(
|
payload = TaskCreationRequest(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
|||||||
if a > 7:
|
if a > 7:
|
||||||
raise ValueError("Too many images, maximum allowed is 7.")
|
raise ValueError("Too many images, maximum allowed is 7.")
|
||||||
for image in images:
|
for image in images:
|
||||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||||
payload = TaskCreationRequest(
|
payload = TaskCreationRequest(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
|||||||
resolution: str,
|
resolution: str,
|
||||||
movement_amplitude: str,
|
movement_amplitude: str,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||||
payload = TaskCreationRequest(
|
payload = TaskCreationRequest(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
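As a usage note for the renamed closeness check above, a short hedged sketch (semantics follow the validation_utils definition later in this diff):

from comfy_api_nodes.util import validate_images_aspect_ratio_closeness

# The closeness factor C = max(ar1, ar2) / min(ar1, ar2) must stay within
# limit = max(max_rel, 1 / min_rel); for 0.8..1.25 the limit is 1.25.
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)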
|
|||||||
@ -14,6 +14,7 @@ from .conversions import (
|
|||||||
downscale_image_tensor,
|
downscale_image_tensor,
|
||||||
image_tensor_pair_to_batch,
|
image_tensor_pair_to_batch,
|
||||||
pil_to_bytesio,
|
pil_to_bytesio,
|
||||||
|
resize_mask_to_image,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
tensor_to_pil,
|
tensor_to_pil,
|
||||||
@ -34,12 +35,12 @@ from .upload_helpers import (
|
|||||||
)
|
)
|
||||||
from .validation_utils import (
|
from .validation_utils import (
|
||||||
get_number_of_images,
|
get_number_of_images,
|
||||||
validate_aspect_ratio_closeness,
|
validate_aspect_ratio_string,
|
||||||
validate_audio_duration,
|
validate_audio_duration,
|
||||||
validate_container_format_is_mp4,
|
validate_container_format_is_mp4,
|
||||||
validate_image_aspect_ratio,
|
validate_image_aspect_ratio,
|
||||||
validate_image_aspect_ratio_range,
|
|
||||||
validate_image_dimensions,
|
validate_image_dimensions,
|
||||||
|
validate_images_aspect_ratio_closeness,
|
||||||
validate_string,
|
validate_string,
|
||||||
validate_video_dimensions,
|
validate_video_dimensions,
|
||||||
validate_video_duration,
|
validate_video_duration,
|
||||||
@ -70,6 +71,7 @@ __all__ = [
|
|||||||
"downscale_image_tensor",
|
"downscale_image_tensor",
|
||||||
"image_tensor_pair_to_batch",
|
"image_tensor_pair_to_batch",
|
||||||
"pil_to_bytesio",
|
"pil_to_bytesio",
|
||||||
|
"resize_mask_to_image",
|
||||||
"tensor_to_base64_string",
|
"tensor_to_base64_string",
|
||||||
"tensor_to_bytesio",
|
"tensor_to_bytesio",
|
||||||
"tensor_to_pil",
|
"tensor_to_pil",
|
||||||
@ -77,12 +79,12 @@ __all__ = [
|
|||||||
"video_to_base64_string",
|
"video_to_base64_string",
|
||||||
# Validation utilities
|
# Validation utilities
|
||||||
"get_number_of_images",
|
"get_number_of_images",
|
||||||
"validate_aspect_ratio_closeness",
|
"validate_aspect_ratio_string",
|
||||||
"validate_audio_duration",
|
"validate_audio_duration",
|
||||||
"validate_container_format_is_mp4",
|
"validate_container_format_is_mp4",
|
||||||
"validate_image_aspect_ratio",
|
"validate_image_aspect_ratio",
|
||||||
"validate_image_aspect_ratio_range",
|
|
||||||
"validate_image_dimensions",
|
"validate_image_dimensions",
|
||||||
|
"validate_images_aspect_ratio_closeness",
|
||||||
"validate_string",
|
"validate_string",
|
||||||
"validate_video_dimensions",
|
"validate_video_dimensions",
|
||||||
"validate_video_duration",
|
"validate_video_duration",
|
||||||
|
|||||||
@ -430,3 +430,24 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
|
|||||||
wav = torch.cat(frames, dim=1) # [C, T]
|
wav = torch.cat(frames, dim=1) # [C, T]
|
||||||
wav = _f32_pcm(wav)
|
wav = _f32_pcm(wav)
|
||||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||||
|
|
||||||
|
|
||||||
|
def resize_mask_to_image(
|
||||||
|
mask: torch.Tensor,
|
||||||
|
image: torch.Tensor,
|
||||||
|
upscale_method="nearest-exact",
|
||||||
|
crop="disabled",
|
||||||
|
allow_gradient=True,
|
||||||
|
add_channel_dim=False,
|
||||||
|
):
|
||||||
|
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
|
||||||
|
_, height, width, _ = image.shape
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
mask = mask.movedim(-1, 1)
|
||||||
|
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
|
||||||
|
mask = mask.movedim(1, -1)
|
||||||
|
if not add_channel_dim:
|
||||||
|
mask = mask.squeeze(-1)
|
||||||
|
if not allow_gradient:
|
||||||
|
mask = (mask > 0.5).float()
|
||||||
|
return mask
|
||||||
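A brief, hedged usage sketch for the relocated resize_mask_to_image helper above; the shapes assume ComfyUI's usual [B, H, W, C] image and [B, H, W] mask tensors (consistent with the unpacking of image.shape in the function), and the values are purely illustrative:

import torch

image = torch.zeros(1, 512, 768, 3)  # [B, H, W, C]
mask = torch.rand(1, 256, 384)       # [B, H, W]

resized = resize_mask_to_image(mask, image)                     # -> shape [1, 512, 768]
hard = resize_mask_to_image(mask, image, allow_gradient=False)  # thresholded at 0.5
chan = resize_mask_to_image(mask, image, add_channel_dim=True)  # -> shape [1, 512, 768, 1]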
|
|||||||
@ -37,63 +37,62 @@ def validate_image_dimensions(
|
|||||||
|
|
||||||
def validate_image_aspect_ratio(
|
def validate_image_aspect_ratio(
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
min_aspect_ratio: Optional[float] = None,
|
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||||
max_aspect_ratio: Optional[float] = None,
|
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||||
):
|
|
||||||
width, height = get_image_dimensions(image)
|
|
||||||
aspect_ratio = width / height
|
|
||||||
|
|
||||||
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
|
|
||||||
raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
|
|
||||||
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
|
|
||||||
raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
|
|
||||||
|
|
||||||
|
|
||||||
def validate_image_aspect_ratio_range(
|
|
||||||
image: torch.Tensor,
|
|
||||||
min_ratio: tuple[float, float], # e.g. (1, 4)
|
|
||||||
max_ratio: tuple[float, float], # e.g. (4, 1)
|
|
||||||
*,
|
*,
|
||||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||||
) -> float:
|
) -> float:
|
||||||
a1, b1 = min_ratio
|
"""Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
|
||||||
a2, b2 = max_ratio
|
|
||||||
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
|
|
||||||
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
|
|
||||||
lo, hi = (a1 / b1), (a2 / b2)
|
|
||||||
if lo > hi:
|
|
||||||
lo, hi = hi, lo
|
|
||||||
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
|
|
||||||
w, h = get_image_dimensions(image)
|
w, h = get_image_dimensions(image)
|
||||||
if w <= 0 or h <= 0:
|
if w <= 0 or h <= 0:
|
||||||
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
||||||
ar = w / h
|
ar = w / h
|
||||||
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
|
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||||
if not ok:
|
|
||||||
op = "<" if strict else "≤"
|
|
||||||
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
|
|
||||||
return ar
|
return ar
|
||||||
|
|
||||||
|
|
||||||
def validate_aspect_ratio_closeness(
|
def validate_images_aspect_ratio_closeness(
|
||||||
start_img,
|
first_image: torch.Tensor,
|
||||||
end_img,
|
second_image: torch.Tensor,
|
||||||
min_rel: float,
|
min_rel: float, # e.g. 0.8
|
||||||
max_rel: float,
|
max_rel: float, # e.g. 1.25
|
||||||
*,
|
*,
|
||||||
strict: bool = False, # True => exclusive, False => inclusive
|
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||||
) -> None:
|
) -> float:
|
||||||
w1, h1 = get_image_dimensions(start_img)
|
"""
|
||||||
w2, h2 = get_image_dimensions(end_img)
|
Validates that the two images' aspect ratios are 'close'.
|
||||||
|
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
|
||||||
|
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
|
||||||
|
|
||||||
|
Returns the computed closeness factor C.
|
||||||
|
"""
|
||||||
|
w1, h1 = get_image_dimensions(first_image)
|
||||||
|
w2, h2 = get_image_dimensions(second_image)
|
||||||
if min(w1, h1, w2, h2) <= 0:
|
if min(w1, h1, w2, h2) <= 0:
|
||||||
raise ValueError("Invalid image dimensions")
|
raise ValueError("Invalid image dimensions")
|
||||||
ar1 = w1 / h1
|
ar1 = w1 / h1
|
||||||
ar2 = w2 / h2
|
ar2 = w2 / h2
|
||||||
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
|
|
||||||
closeness = max(ar1, ar2) / min(ar1, ar2)
|
closeness = max(ar1, ar2) / min(ar1, ar2)
|
||||||
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
|
limit = max(max_rel, 1.0 / min_rel)
|
||||||
if (closeness >= limit) if strict else (closeness > limit):
|
if (closeness >= limit) if strict else (closeness > limit):
|
||||||
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
|
raise ValueError(
|
||||||
|
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
|
||||||
|
f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})."
|
||||||
|
)
|
||||||
|
return closeness
|
||||||
|
|
||||||
|
|
||||||
|
def validate_aspect_ratio_string(
|
||||||
|
aspect_ratio: str,
|
||||||
|
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||||
|
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||||
|
*,
|
||||||
|
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||||
|
) -> float:
|
||||||
|
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
|
||||||
|
ar = _parse_aspect_ratio_string(aspect_ratio)
|
||||||
|
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||||
|
return ar
|
||||||
|
|
||||||
|
|
||||||
def validate_video_dimensions(
|
def validate_video_dimensions(
|
||||||
@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None:
|
|||||||
container_format = video.get_container_format()
|
container_format = video.get_container_format()
|
||||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||||
|
|
||||||
|
|
||||||
|
def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
||||||
|
a, b = r
|
||||||
|
if a <= 0 or b <= 0:
|
||||||
|
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
|
||||||
|
return a / b
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_ratio_bounds(
|
||||||
|
ar: float,
|
||||||
|
*,
|
||||||
|
min_ratio: Optional[tuple[float, float]] = None,
|
||||||
|
max_ratio: Optional[tuple[float, float]] = None,
|
||||||
|
strict: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
||||||
|
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
|
||||||
|
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
|
||||||
|
|
||||||
|
if lo is not None and hi is not None and lo > hi:
|
||||||
|
lo, hi = hi, lo # normalize order if caller swapped them
|
||||||
|
|
||||||
|
if lo is not None:
|
||||||
|
if (ar <= lo) if strict else (ar < lo):
|
||||||
|
op = "<" if strict else "≤"
|
||||||
|
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
|
||||||
|
if hi is not None:
|
||||||
|
if (ar >= hi) if strict else (ar > hi):
|
||||||
|
op = "<" if strict else "≤"
|
||||||
|
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_aspect_ratio_string(ar_str: str) -> float:
|
||||||
|
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
|
||||||
|
parts = ar_str.split(":")
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
|
||||||
|
try:
|
||||||
|
a = int(parts[0].strip())
|
||||||
|
b = int(parts[1].strip())
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
|
||||||
|
if a <= 0 or b <= 0:
|
||||||
|
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
|
||||||
|
return a / b
|
||||||
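A few hedged examples of the new string validator above (return values and errors follow directly from the helpers defined in this hunk; the inputs are illustrative):

validate_aspect_ratio_string("16:9", min_ratio=(1, 4), max_ratio=(4, 1))  # returns 16/9 ~= 1.78
validate_aspect_ratio_string("16:0")                                      # ValueError: parts must be positive
validate_aspect_ratio_string("21:9", max_ratio=(2, 1))                    # ValueError: ~2.3 exceeds 2.0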
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
|
import bisect
|
||||||
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
|
import psutil
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
from typing import Sequence, Mapping, Dict
|
from typing import Sequence, Mapping, Dict
|
||||||
from comfy_execution.graph import DynamicPrompt
|
from comfy_execution.graph import DynamicPrompt
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -188,6 +193,9 @@ class BasicCache:
|
|||||||
self._clean_cache()
|
self._clean_cache()
|
||||||
self._clean_subcaches()
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
def poll(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
def _set_immediate(self, node_id, value):
|
def _set_immediate(self, node_id, value):
|
||||||
assert self.initialized
|
assert self.initialized
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
@ -276,6 +284,9 @@ class NullCache:
|
|||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def poll(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
def get(self, node_id):
|
def get(self, node_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -336,3 +347,75 @@ class LRUCache(BasicCache):
|
|||||||
self._mark_used(child_id)
|
self._mark_used(child_id)
|
||||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
#Iterating the cache for usage analysis might be expensive, so once eviction triggers make sure
|
||||||
|
#to take a chunk out to give breathing room on high-node / low-RAM-per-node workflows.
|
||||||
|
|
||||||
|
RAM_CACHE_HYSTERESIS = 1.1
|
||||||
|
|
||||||
|
#This is roughly in GB, but not exactly. It needs to be non-zero for the heuristic below,
|
||||||
|
#and as long as multi-GB models dwarf it the OOM scoring remains a reasonable approximation.
|
||||||
|
|
||||||
|
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
|
||||||
|
|
||||||
|
#Exponential bias towards evicting older workflows so garbage will be taken out
|
||||||
|
#in constantly changing setups.
|
||||||
|
|
||||||
|
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||||
|
|
||||||
|
class RAMPressureCache(LRUCache):
|
||||||
|
|
||||||
|
def __init__(self, key_class):
|
||||||
|
super().__init__(key_class, 0)
|
||||||
|
self.timestamps = {}
|
||||||
|
|
||||||
|
def clean_unused(self):
|
||||||
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
def set(self, node_id, value):
|
||||||
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
|
super().set(node_id, value)
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
|
return super().get(node_id)
|
||||||
|
|
||||||
|
def poll(self, ram_headroom):
|
||||||
|
def _ram_gb():
|
||||||
|
return psutil.virtual_memory().available / (1024**3)
|
||||||
|
|
||||||
|
if _ram_gb() > ram_headroom:
|
||||||
|
return
|
||||||
|
gc.collect()
|
||||||
|
if _ram_gb() > ram_headroom:
|
||||||
|
return
|
||||||
|
|
||||||
|
clean_list = []
|
||||||
|
|
||||||
|
for key, (outputs, _), in self.cache.items():
|
||||||
|
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||||
|
|
||||||
|
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||||
|
def scan_list_for_ram_usage(outputs):
|
||||||
|
nonlocal ram_usage
|
||||||
|
for output in outputs:
|
||||||
|
if isinstance(output, list):
|
||||||
|
scan_list_for_ram_usage(output)
|
||||||
|
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||||
|
#score Tensors at a 50% discount for RAM usage as they are likely to
|
||||||
|
#be high value intermediates
|
||||||
|
ram_usage += (output.numel() * output.element_size()) * 0.5
|
||||||
|
elif hasattr(output, "get_ram_usage"):
|
||||||
|
ram_usage += output.get_ram_usage()
|
||||||
|
scan_list_for_ram_usage(outputs)
|
||||||
|
|
||||||
|
oom_score *= ram_usage
|
||||||
|
#In the case where we have no information on the node RAM usage at all,
|
||||||
|
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||||
|
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||||
|
|
||||||
|
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
||||||
|
_, _, key = clean_list.pop()
|
||||||
|
del self.cache[key]
|
||||||
|
gc.collect()
|
||||||
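To make the eviction ordering above easier to follow, here is a hedged, standalone sketch of the score each cache entry gets in poll(); larger scores are evicted first, and the numbers are purely illustrative:

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

def oom_score(ram_usage: float, generations_since_last_use: int) -> float:
    # Age bias compounds per workflow generation, then scales by the estimated RAM usage.
    return (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** generations_since_last_use) * ram_usage

# A 6 GB output untouched for three workflows outranks a 10 GB output used this run:
print(oom_score(6.0, 3))   # ~13.2 -> evicted first
print(oom_score(10.0, 0))  # 10.0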
|
|||||||
@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort):
|
|||||||
self.execution_cache_listeners[from_node_id] = set()
|
self.execution_cache_listeners[from_node_id] = set()
|
||||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||||
|
|
||||||
def get_output_cache(self, from_node_id, to_node_id):
|
def get_cache(self, from_node_id, to_node_id):
|
||||||
if not to_node_id in self.execution_cache:
|
if not to_node_id in self.execution_cache:
|
||||||
return None
|
return None
|
||||||
return self.execution_cache[to_node_id].get(from_node_id)
|
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
#Write back to the main cache on touch.
|
||||||
|
self.output_cache.set(from_node_id, value)
|
||||||
|
return value
|
||||||
|
|
||||||
def cache_update(self, node_id, value):
|
def cache_update(self, node_id, value):
|
||||||
if node_id in self.execution_cache_listeners:
|
if node_id in self.execution_cache_listeners:
|
||||||
|
|||||||
77 execution.py
@ -21,6 +21,7 @@ from comfy_execution.caching import (
|
|||||||
NullCache,
|
NullCache,
|
||||||
HierarchicalCache,
|
HierarchicalCache,
|
||||||
LRUCache,
|
LRUCache,
|
||||||
|
RAMPressureCache,
|
||||||
)
|
)
|
||||||
from comfy_execution.graph import (
|
from comfy_execution.graph import (
|
||||||
DynamicPrompt,
|
DynamicPrompt,
|
||||||
@ -88,49 +89,56 @@ class IsChangedCache:
|
|||||||
return self.is_changed[node_id]
|
return self.is_changed[node_id]
|
||||||
|
|
||||||
|
|
||||||
|
class CacheEntry(NamedTuple):
|
||||||
|
ui: dict
|
||||||
|
outputs: list
|
||||||
|
|
||||||
|
|
||||||
class CacheType(Enum):
|
class CacheType(Enum):
|
||||||
CLASSIC = 0
|
CLASSIC = 0
|
||||||
LRU = 1
|
LRU = 1
|
||||||
NONE = 2
|
NONE = 2
|
||||||
|
RAM_PRESSURE = 3
|
||||||
|
|
||||||
|
|
||||||
class CacheSet:
|
class CacheSet:
|
||||||
def __init__(self, cache_type=None, cache_size=None):
|
def __init__(self, cache_type=None, cache_args={}):
|
||||||
if cache_type == CacheType.NONE:
|
if cache_type == CacheType.NONE:
|
||||||
self.init_null_cache()
|
self.init_null_cache()
|
||||||
logging.info("Disabling intermediate node cache.")
|
logging.info("Disabling intermediate node cache.")
|
||||||
|
elif cache_type == CacheType.RAM_PRESSURE:
|
||||||
|
cache_ram = cache_args.get("ram", 16.0)
|
||||||
|
self.init_ram_cache(cache_ram)
|
||||||
|
logging.info("Using RAM pressure cache.")
|
||||||
elif cache_type == CacheType.LRU:
|
elif cache_type == CacheType.LRU:
|
||||||
if cache_size is None:
|
cache_size = cache_args.get("lru", 0)
|
||||||
cache_size = 0
|
|
||||||
self.init_lru_cache(cache_size)
|
self.init_lru_cache(cache_size)
|
||||||
logging.info("Using LRU cache")
|
logging.info("Using LRU cache")
|
||||||
else:
|
else:
|
||||||
self.init_classic_cache()
|
self.init_classic_cache()
|
||||||
|
|
||||||
self.all = [self.outputs, self.ui, self.objects]
|
self.all = [self.outputs, self.objects]
|
||||||
|
|
||||||
# Performs like the old cache -- dump data ASAP
|
# Performs like the old cache -- dump data ASAP
|
||||||
def init_classic_cache(self):
|
def init_classic_cache(self):
|
||||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
|
||||||
self.objects = HierarchicalCache(CacheKeySetID)
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
def init_lru_cache(self, cache_size):
|
def init_lru_cache(self, cache_size):
|
||||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
|
def init_ram_cache(self, min_headroom):
|
||||||
|
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
|
||||||
self.objects = HierarchicalCache(CacheKeySetID)
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
def init_null_cache(self):
|
def init_null_cache(self):
|
||||||
self.outputs = NullCache()
|
self.outputs = NullCache()
|
||||||
#The UI cache is expected to be iterable at the end of each workflow
|
|
||||||
#so it must cache at least a full workflow. Use Heirachical
|
|
||||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
|
||||||
self.objects = NullCache()
|
self.objects = NullCache()
|
||||||
|
|
||||||
def recursive_debug_dump(self):
|
def recursive_debug_dump(self):
|
||||||
result = {
|
result = {
|
||||||
"outputs": self.outputs.recursive_debug_dump(),
|
"outputs": self.outputs.recursive_debug_dump(),
|
||||||
"ui": self.ui.recursive_debug_dump(),
|
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
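A minimal sketch of how the CacheEntry introduced above replaces the separate ui cache; the field and variable names are those used in the execute() changes further down, and the meta dict is abridged:

entry = CacheEntry(
    ui={"meta": {"node_id": unique_id, "display_node": display_node_id}, "output": output_ui},
    outputs=output_data,
)
caches.outputs.set(unique_id, entry)

# On a cache hit the UI payload now travels with the outputs:
cached = caches.outputs.get(unique_id)
if cached is not None and cached.ui is not None:
    ui_outputs[unique_id] = cached.ui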
@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
|||||||
if execution_list is None:
|
if execution_list is None:
|
||||||
mark_missing()
|
mark_missing()
|
||||||
continue # This might be a lazily-evaluated input
|
continue # This might be a lazily-evaluated input
|
||||||
cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
|
cached = execution_list.get_cache(input_unique_id, unique_id)
|
||||||
if cached_output is None:
|
if cached is None or cached.outputs is None:
|
||||||
mark_missing()
|
mark_missing()
|
||||||
continue
|
continue
|
||||||
if output_index >= len(cached_output):
|
if output_index >= len(cached.outputs):
|
||||||
mark_missing()
|
mark_missing()
|
||||||
continue
|
continue
|
||||||
obj = cached_output[output_index]
|
obj = cached.outputs[output_index]
|
||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
elif input_category is not None:
|
elif input_category is not None:
|
||||||
input_data_all[x] = [input_data]
|
input_data_all[x] = [input_data]
|
||||||
@ -393,7 +401,7 @@ def format_value(x):
|
|||||||
else:
|
else:
|
||||||
return str(x)
|
return str(x)
|
||||||
|
|
||||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
|
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||||
unique_id = current_item
|
unique_id = current_item
|
||||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||||
@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
if caches.outputs.get(unique_id) is not None:
|
cached = caches.outputs.get(unique_id)
|
||||||
|
if cached is not None:
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
cached_output = caches.ui.get(unique_id) or {}
|
cached_ui = cached.ui or {}
|
||||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||||
|
if cached.ui is not None:
|
||||||
|
ui_outputs[unique_id] = cached.ui
|
||||||
get_progress_state().finish_progress(unique_id)
|
get_progress_state().finish_progress(unique_id)
|
||||||
execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
|
execution_list.cache_update(unique_id, cached)
|
||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
input_data_all = None
|
input_data_all = None
|
||||||
@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
for r in result:
|
for r in result:
|
||||||
if is_link(r):
|
if is_link(r):
|
||||||
source_node, source_output = r[0], r[1]
|
source_node, source_output = r[0], r[1]
|
||||||
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
|
node_cached = execution_list.get_cache(source_node, unique_id)
|
||||||
for o in node_output:
|
for o in node_cached.outputs[source_output]:
|
||||||
resolved_output.append(o)
|
resolved_output.append(o)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -507,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
asyncio.create_task(await_completion())
|
asyncio.create_task(await_completion())
|
||||||
return (ExecutionResult.PENDING, None, None)
|
return (ExecutionResult.PENDING, None, None)
|
||||||
if len(output_ui) > 0:
|
if len(output_ui) > 0:
|
||||||
caches.ui.set(unique_id, {
|
ui_outputs[unique_id] = {
|
||||||
"meta": {
|
"meta": {
|
||||||
"node_id": unique_id,
|
"node_id": unique_id,
|
||||||
"display_node": display_node_id,
|
"display_node": display_node_id,
|
||||||
@ -515,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
"real_node_id": real_node_id,
|
"real_node_id": real_node_id,
|
||||||
},
|
},
|
||||||
"output": output_ui
|
"output": output_ui
|
||||||
})
|
}
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||||
if has_subgraph:
|
if has_subgraph:
|
||||||
@ -554,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
pending_subgraph_results[unique_id] = cached_outputs
|
pending_subgraph_results[unique_id] = cached_outputs
|
||||||
return (ExecutionResult.PENDING, None, None)
|
return (ExecutionResult.PENDING, None, None)
|
||||||
|
|
||||||
caches.outputs.set(unique_id, output_data)
|
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
||||||
execution_list.cache_update(unique_id, output_data)
|
execution_list.cache_update(unique_id, cache_entry)
|
||||||
|
caches.outputs.set(unique_id, cache_entry)
|
||||||
|
|
||||||
except comfy.model_management.InterruptProcessingException as iex:
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
logging.info("Processing interrupted")
|
logging.info("Processing interrupted")
|
||||||
@ -600,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
return (ExecutionResult.SUCCESS, None, None)
|
return (ExecutionResult.SUCCESS, None, None)
|
||||||
|
|
||||||
class PromptExecutor:
|
class PromptExecutor:
|
||||||
def __init__(self, server, cache_type=False, cache_size=None):
|
def __init__(self, server, cache_type=False, cache_args=None):
|
||||||
self.cache_size = cache_size
|
self.cache_args = cache_args
|
||||||
self.cache_type = cache_type
|
self.cache_type = cache_type
|
||||||
self.server = server
|
self.server = server
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
|
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.success = True
|
self.success = True
|
||||||
|
|
||||||
@ -682,6 +694,7 @@ class PromptExecutor:
|
|||||||
broadcast=False)
|
broadcast=False)
|
||||||
pending_subgraph_results = {}
|
pending_subgraph_results = {}
|
||||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||||
|
ui_node_outputs = {}
|
||||||
executed = set()
|
executed = set()
|
||||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||||
current_outputs = self.caches.outputs.all_node_ids()
|
current_outputs = self.caches.outputs.all_node_ids()
|
||||||
@ -695,7 +708,7 @@ class PromptExecutor:
|
|||||||
break
|
break
|
||||||
|
|
||||||
assert node_id is not None, "Node ID should not be None at this point"
|
assert node_id is not None, "Node ID should not be None at this point"
|
||||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||||
self.success = result != ExecutionResult.FAILURE
|
self.success = result != ExecutionResult.FAILURE
|
||||||
if result == ExecutionResult.FAILURE:
|
if result == ExecutionResult.FAILURE:
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
@ -704,16 +717,14 @@ class PromptExecutor:
|
|||||||
execution_list.unstage_node_execution()
|
execution_list.unstage_node_execution()
|
||||||
else: # result == ExecutionResult.SUCCESS:
|
else: # result == ExecutionResult.SUCCESS:
|
||||||
execution_list.complete_node_execution()
|
execution_list.complete_node_execution()
|
||||||
|
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||||
else:
|
else:
|
||||||
# Only execute when the while-loop ends without break
|
# Only execute when the while-loop ends without break
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||||
|
|
||||||
ui_outputs = {}
|
ui_outputs = {}
|
||||||
meta_outputs = {}
|
meta_outputs = {}
|
||||||
all_node_ids = self.caches.ui.all_node_ids()
|
for node_id, ui_info in ui_node_outputs.items():
|
||||||
for node_id in all_node_ids:
|
|
||||||
ui_info = self.caches.ui.get(node_id)
|
|
||||||
if ui_info is not None:
|
|
||||||
ui_outputs[node_id] = ui_info["output"]
|
ui_outputs[node_id] = ui_info["output"]
|
||||||
meta_outputs[node_id] = ui_info["meta"]
|
meta_outputs[node_id] = ui_info["meta"]
|
||||||
self.history_result = {
|
self.history_result = {
|
||||||
|
|||||||
4 main.py
@ -198,10 +198,12 @@ def prompt_worker(q, server_instance):
|
|||||||
cache_type = execution.CacheType.CLASSIC
|
cache_type = execution.CacheType.CLASSIC
|
||||||
if args.cache_lru > 0:
|
if args.cache_lru > 0:
|
||||||
cache_type = execution.CacheType.LRU
|
cache_type = execution.CacheType.LRU
|
||||||
|
elif args.cache_ram > 0:
|
||||||
|
cache_type = execution.CacheType.RAM_PRESSURE
|
||||||
elif args.cache_none:
|
elif args.cache_none:
|
||||||
cache_type = execution.CacheType.NONE
|
cache_type = execution.CacheType.NONE
|
||||||
|
|
||||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
|
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
|
||||||
last_gc_collect = 0
|
last_gc_collect = 0
|
||||||
need_gc = False
|
need_gc = False
|
||||||
gc_collect_interval = 10.0
|
gc_collect_interval = 10.0
|
||||||
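Hedged behaviour note: the value passed via --cache-ram ends up as cache_args["ram"] and is compared against available system RAM inside RAMPressureCache.poll(); eviction only starts once available RAM drops below that headroom, roughly:

import psutil

def under_headroom(ram_headroom_gb: float) -> bool:
    # Mirrors the check at the top of RAMPressureCache.poll().
    available_gb = psutil.virtual_memory().available / (1024 ** 3)
    return available_gb < ram_headroom_gb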
|
|||||||