Merge branch 'master' into dr-support-pip-cm

Dr.Lt.Data 2025-10-31 07:31:50 +09:00
commit ad4b959d7e
22 changed files with 390 additions and 503 deletions

View File

@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

View File

@ -276,6 +276,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model) self.size = comfy.model_management.module_size(self.model)
return self.size return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self): def loaded_size(self):
return self.model.model_loaded_weight_memory return self.model.model_loaded_weight_memory
@ -655,6 +658,7 @@ class ModelPatcher:
mem_counter = 0 mem_counter = 0
patch_counter = 0 patch_counter = 0
lowvram_counter = 0 lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list() loading = self._load_list()
load_completely = [] load_completely = []
@ -675,6 +679,7 @@ class ModelPatcher:
if mem_counter + module_mem >= lowvram_model_memory: if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True lowvram_weight = True
lowvram_counter += 1 lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue continue
@ -748,10 +753,10 @@ class ModelPatcher:
self.pin_weight_to_device("{}.{}".format(n, param)) self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0: if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
else: else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False self.model.model_lowvram = False
if full_load: if full_load:
self.model.to(device_to) self.model.to(device_to)

View File

@ -421,14 +421,18 @@ def fp8_linear(self, input):
if scale_input is None: if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
else: else:
scale_input = scale_input.to(input.device) scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
# Wrap weight in QuantizedTensor - this enables unified dispatch # Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
uncast_bias_weight(self, w, bias, offload_stream) uncast_bias_weight(self, w, bias, offload_stream)

View File

@ -357,9 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout):
scale = torch.tensor(scale) scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32) scale = scale.to(device=tensor.device, dtype=torch.float32)
lp_amax = torch.finfo(dtype).max
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) # TODO: uncomment this if it's actually needed because the clamp has a small performance penalty
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
layout_params = { layout_params = {

View File

@ -143,6 +143,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model) return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -293,6 +296,7 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False self.disable_offload = False
self.not_video = False self.not_video = False
self.size = None
self.downscale_index_formula = None self.downscale_index_formula = None
self.upscale_index_formula = None self.upscale_index_formula = None
@ -595,6 +599,16 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self): def throw_exception_if_invalid(self):
if self.first_stage_model is None: if self.first_stage_model is None:

View File

@ -1,15 +1,12 @@
from __future__ import annotations from __future__ import annotations
import aiohttp import aiohttp
import mimetypes import mimetypes
from typing import Optional, Union from typing import Union
from comfy.utils import common_upscale
from server import PromptServer from server import PromptServer
from comfy.cli_args import args
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch import torch
import math
import base64 import base64
from io import BytesIO from io import BytesIO
@ -60,85 +57,6 @@ async def validate_and_cast_response(
return torch.stack(image_tensors, dim=0) return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def text_filepath_to_base64_string(filepath: str) -> str: def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string.""" """Converts a text file to a base64 string."""
with open(filepath, "rb") as f: with open(filepath, "rb") as f:
@ -153,28 +71,3 @@ def text_filepath_to_data_uri(filepath: str) -> str:
if mime_type is None: if mime_type is None:
mime_type = "application/octet-stream" mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}" return f"data:{mime_type};base64,{base64_string}"
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask

View File

@ -5,10 +5,6 @@ import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_aspect_ratio,
)
from comfy_api_nodes.apis.bfl_api import ( from comfy_api_nodes.apis.bfl_api import (
BFLFluxExpandImageRequest, BFLFluxExpandImageRequest,
BFLFluxFillImageRequest, BFLFluxFillImageRequest,
@ -23,8 +19,10 @@ from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
download_url_to_image_tensor, download_url_to_image_tensor,
poll_op, poll_op,
resize_mask_to_image,
sync_op, sync_op,
tensor_to_base64_string, tensor_to_base64_string,
validate_aspect_ratio_string,
validate_string, validate_string,
) )
@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode):
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
""" """
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
@classmethod @classmethod
def validate_inputs(cls, aspect_ratio: str): def validate_inputs(cls, aspect_ratio: str):
try: validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True return True
@classmethod @classmethod
@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
prompt_upsampling=prompt_upsampling, prompt_upsampling=prompt_upsampling,
seed=seed, seed=seed,
aspect_ratio=validate_aspect_ratio( aspect_ratio=aspect_ratio,
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
),
raw=raw, raw=raw,
image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
""" """
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
seed=0, seed=0,
prompt_upsampling=False, prompt_upsampling=False,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
aspect_ratio = validate_aspect_ratio( validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
if input_image is None: if input_image is None:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
initial_response = await sync_op( initial_response = await sync_op(

View File

@ -17,7 +17,7 @@ from comfy_api_nodes.util import (
poll_op, poll_op,
sync_op, sync_op,
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_image_aspect_ratio_range, validate_image_aspect_ratio,
validate_image_dimensions, validate_image_dimensions,
validate_string, validate_string,
) )
@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1: if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.") raise ValueError("Exactly one input image is required.")
validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) validate_image_aspect_ratio(image, (1, 3), (3, 1))
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
payload = Image2ImageTaskCreationRequest( payload = Image2ImageTaskCreationRequest(
model=model, model=model,
@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
reference_images_urls = [] reference_images_urls = []
if n_input_images: if n_input_images:
for i in image: for i in image:
validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) validate_image_aspect_ratio(i, (1, 3), (3, 1))
reference_images_urls = await upload_images_to_comfyapi( reference_images_urls = await upload_images_to_comfyapi(
cls, cls,
image, image,
@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
prompt = ( prompt = (
@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
for i in (first_frame, last_frame): for i in (first_frame, last_frame):
validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
for image in images: for image in images:
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
prompt = ( prompt = (

View File

@ -1,6 +1,6 @@
from io import BytesIO from io import BytesIO
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO from comfy_api.latest import IO, ComfyExtension
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import torch import torch
@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
IdeogramV3Request, IdeogramV3Request,
IdeogramV3EditRequest, IdeogramV3EditRequest,
) )
from comfy_api_nodes.util import (
from comfy_api_nodes.apis.client import (
ApiEndpoint, ApiEndpoint,
HttpMethod, bytesio_to_image_tensor,
SynchronousOperation, download_url_as_bytesio,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_bytesio,
resize_mask_to_image, resize_mask_to_image,
sync_op,
) )
from comfy_api_nodes.util import bytesio_to_image_tensor
from server import PromptServer
V1_V1_RES_MAP = { V1_V1_RES_MAP = {
"Auto":"AUTO", "Auto":"AUTO",
@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
for image_url in image_urls: for image_url in image_urls:
# Using functions from apinode_utils.py to handle downloading and processing # Using functions from apinode_utils.py to handle downloading and processing
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
image_tensors.append(img_tensor) image_tensors.append(img_tensor)
@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
return stacked_tensors return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
class IdeogramV1(IO.ComfyNode): class IdeogramV1(IO.ComfyNode):
@classmethod @classmethod
@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
model = "V_1_TURBO" if turbo else "V_1" model = "V_1_TURBO" if turbo else "V_1"
auth = { response = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
} response_model=IdeogramGenerateResponse,
operation = SynchronousOperation( data=IdeogramGenerateRequest(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest( image_request=ImageRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
num_images=num_images, num_images=num_images,
seed=seed, seed=seed,
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
magic_prompt_option=( magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
) )
), ),
auth_kwargs=auth, max_retries=1,
) )
response = await operation.execute()
if not response.data or len(response.data) == 0: if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response") raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url] image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls)) return IO.NodeOutput(await download_and_process_images(image_urls))
@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
else: else:
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
auth = { response = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
} response_model=IdeogramGenerateResponse,
operation = SynchronousOperation( data=IdeogramGenerateRequest(
endpoint=ApiEndpoint(
path="/proxy/ideogram/generate",
method=HttpMethod.POST,
request_model=IdeogramGenerateRequest,
response_model=IdeogramGenerateResponse,
),
request=IdeogramGenerateRequest(
image_request=ImageRequest( image_request=ImageRequest(
prompt=prompt, prompt=prompt,
model=model, model=model,
@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
seed=seed, seed=seed,
aspect_ratio=final_aspect_ratio, aspect_ratio=final_aspect_ratio,
resolution=final_resolution, resolution=final_resolution,
magic_prompt_option=( magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
magic_prompt_option if magic_prompt_option != "AUTO" else None
),
style_type=style_type if style_type != "NONE" else None, style_type=style_type if style_type != "NONE" else None,
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
color_palette=color_palette if color_palette else None, color_palette=color_palette if color_palette else None,
) )
), ),
auth_kwargs=auth, max_retries=1,
) )
response = await operation.execute()
if not response.data or len(response.data) == 0: if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response") raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url] image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls)) return IO.NodeOutput(await download_and_process_images(image_urls))
@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
character_image=None, character_image=None,
character_mask=None, character_mask=None,
): ):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT" rendering_speed = "DEFAULT"
@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):
# Check if both image and mask are provided for editing mode # Check if both image and mask are provided for editing mode
if image is not None and mask is not None: if image is not None and mask is not None:
# Edit mode
path = "/proxy/ideogram/ideogram-v3/edit"
# Process image and mask # Process image and mask
input_tensor = image.squeeze().cpu() input_tensor = image.squeeze().cpu()
# Resize mask to match image dimension # Resize mask to match image dimension
@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
if character_mask_binary: if character_mask_binary:
files["character_mask_binary"] = character_mask_binary files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode response = await sync_op(
operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
path=path, response_model=IdeogramGenerateResponse,
method=HttpMethod.POST, data=edit_request,
request_model=IdeogramV3EditRequest,
response_model=IdeogramGenerateResponse,
),
request=edit_request,
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth, max_retries=1,
) )
elif image is not None or mask is not None: elif image is not None or mask is not None:
# If only one of image or mask is provided, raise an error # If only one of image or mask is provided, raise an error
raise Exception("Ideogram V3 image editing requires both an image AND a mask") raise Exception("Ideogram V3 image editing requires both an image AND a mask")
else: else:
# Generation mode
path = "/proxy/ideogram/ideogram-v3/generate"
# Create generation request # Create generation request
gen_request = IdeogramV3Request( gen_request = IdeogramV3Request(
prompt=prompt, prompt=prompt,
@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
if files: if files:
gen_request.style_type = "AUTO" gen_request.style_type = "AUTO"
# Execute the operation for generation mode response = await sync_op(
operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
path=path, response_model=IdeogramGenerateResponse,
method=HttpMethod.POST, data=gen_request,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
files=files if files else None, files=files if files else None,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth, max_retries=1,
) )
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0: if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response") raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url] image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls)) return IO.NodeOutput(await download_and_process_images(image_urls))
@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
IdeogramV3, IdeogramV3,
] ]
async def comfy_entrypoint() -> IdeogramExtension: async def comfy_entrypoint() -> IdeogramExtension:
return IdeogramExtension() return IdeogramExtension()

View File

@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
""" """
validate_image_dimensions(image, min_width=300, min_height=300) validate_image_dimensions(image, min_width=300, min_height=300)
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
def get_video_from_response(response) -> KlingVideoResult: def get_video_from_response(response) -> KlingVideoResult:

View File

@ -225,7 +225,7 @@ class OpenAIDalle2(ComfyNodeABC):
), ),
files=( files=(
{ {
"image": img_binary, "image": ("image.png", img_binary, "image/png"),
} }
if img_binary if img_binary
else None else None

View File

@ -1,7 +1,6 @@
from inspect import cleandoc import torch
from typing import Optional
from typing_extensions import override from typing_extensions import override
from io import BytesIO from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.pixverse_api import ( from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest, PixverseTextVideoRequest,
PixverseImageVideoRequest, PixverseImageVideoRequest,
@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
PixverseIO, PixverseIO,
pixverse_templates, pixverse_templates,
) )
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
HttpMethod, download_url_to_video_output,
SynchronousOperation, poll_op,
PollingOperation, sync_op,
EmptyRequest, tensor_to_bytesio,
validate_string,
) )
from comfy_api_nodes.util import validate_string, tensor_to_bytesio
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
import torch
import aiohttp
AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30 AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52 AVERAGE_DURATION_T2T = 52
def get_video_url_from_response( async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
response: PixverseGenerationStatusResponse, response_upload = await sync_op(
) -> Optional[str]: cls,
if response.Resp is None or response.Resp.url is None: ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
return None response_model=PixverseImageUploadResponse,
return str(response.Resp.url)
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=PixverseImageUploadResponse,
),
request=EmptyRequest(),
files={"image": tensor_to_bytesio(image)}, files={"image": tensor_to_bytesio(image)},
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth_kwargs,
) )
response_upload: PixverseImageUploadResponse = await operation.execute()
if response_upload.Resp is None: if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
return response_upload.Resp.img_id return response_upload.Resp.img_id
@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode):
class PixverseTextToVideoNode(IO.ComfyNode): class PixverseTextToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="PixverseTextToVideoNode", node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video", display_name="PixVerse Text to Video",
category="api node/video/PixVerse", category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""), description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
negative_prompt: str = None, negative_prompt: str = None,
pixverse_template: int = None, pixverse_template: int = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False, min_length=1)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
if quality == PixverseQuality.res_1080p: if quality == PixverseQuality.res_1080p:
@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5: elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal motion_mode = PixverseMotionMode.normal
auth = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
} response_model=PixverseVideoResponse,
operation = SynchronousOperation( data=PixverseTextVideoRequest(
endpoint=ApiEndpoint(
path="/proxy/pixverse/video/text/generate",
method=HttpMethod.POST,
request_model=PixverseTextVideoRequest,
response_model=PixverseVideoResponse,
),
request=PixverseTextVideoRequest(
prompt=prompt, prompt=prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
quality=quality, quality=quality,
@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation( response_poll = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
method=HttpMethod.GET, response_model=PixverseGenerationStatusResponse,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[ failed_statuses=[
PixverseStatus.contents_moderation, PixverseStatus.contents_moderation,
@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseImageToVideoNode(IO.ComfyNode): class PixverseImageToVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="PixverseImageToVideoNode", node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video", display_name="PixVerse Image to Video",
category="api node/video/PixVerse", category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""), description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.String.Input( IO.String.Input(
@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
pixverse_template: int = None, pixverse_template: int = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
auth = { img_id = await upload_image_to_pixverse(cls, image)
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5: elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/pixverse/video/img/generate", ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
method=HttpMethod.POST, response_model=PixverseVideoResponse,
request_model=PixverseImageVideoRequest, data=PixverseImageVideoRequest(
response_model=PixverseVideoResponse,
),
request=PixverseImageVideoRequest(
img_id=img_id, img_id=img_id,
prompt=prompt, prompt=prompt,
quality=quality, quality=quality,
@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
template_id=pixverse_template, template_id=pixverse_template,
seed=seed, seed=seed,
), ),
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation( response_poll = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
method=HttpMethod.GET, response_model=PixverseGenerationStatusResponse,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[ failed_statuses=[
PixverseStatus.contents_moderation, PixverseStatus.contents_moderation,
@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V, estimated_duration=AVERAGE_DURATION_I2V,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixverseTransitionVideoNode(IO.ComfyNode): class PixverseTransitionVideoNode(IO.ComfyNode):
"""
Generates videos based on prompt and output_size.
"""
@classmethod @classmethod
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="PixverseTransitionVideoNode", node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video", display_name="PixVerse Transition Video",
category="api node/video/PixVerse", category="api node/video/PixVerse",
description=cleandoc(cls.__doc__ or ""), description="Generates videos based on prompt and output_size.",
inputs=[ inputs=[
IO.Image.Input("first_frame"), IO.Image.Input("first_frame"),
IO.Image.Input("last_frame"), IO.Image.Input("last_frame"),
@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt: str = None, negative_prompt: str = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
auth = { first_frame_id = await upload_image_to_pixverse(cls, first_frame)
"auth_token": cls.hidden.auth_token_comfy_org, last_frame_id = await upload_image_to_pixverse(cls, last_frame)
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
# 1080p is limited to 5 seconds duration # 1080p is limited to 5 seconds duration
# only normal motion_mode supported for 1080p or for non-5 second duration # only normal motion_mode supported for 1080p or for non-5 second duration
@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
elif duration_seconds != PixverseDuration.dur_5: elif duration_seconds != PixverseDuration.dur_5:
motion_mode = PixverseMotionMode.normal motion_mode = PixverseMotionMode.normal
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/pixverse/video/transition/generate", ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
method=HttpMethod.POST, response_model=PixverseVideoResponse,
request_model=PixverseTransitionVideoRequest, data=PixverseTransitionVideoRequest(
response_model=PixverseVideoResponse,
),
request=PixverseTransitionVideoRequest(
first_frame_img=first_frame_id, first_frame_img=first_frame_id,
last_frame_img=last_frame_id, last_frame_img=last_frame_id,
prompt=prompt, prompt=prompt,
@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
seed=seed, seed=seed,
), ),
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.Resp is None: if response_api.Resp is None:
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
operation = PollingOperation( response_poll = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
method=HttpMethod.GET, response_model=PixverseGenerationStatusResponse,
request_model=EmptyRequest,
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful], completed_statuses=[PixverseStatus.successful],
failed_statuses=[ failed_statuses=[
PixverseStatus.contents_moderation, PixverseStatus.contents_moderation,
@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
PixverseStatus.deleted, PixverseStatus.deleted,
], ],
status_extractor=lambda x: x.Resp.status, status_extractor=lambda x: x.Resp.status,
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V, estimated_duration=AVERAGE_DURATION_T2V,
) )
response_poll = await operation.execute() return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
async with aiohttp.ClientSession() as session:
async with session.get(response_poll.Resp.url) as vid_response:
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
class PixVerseExtension(ComfyExtension): class PixVerseExtension(ComfyExtension):

View File

@ -8,9 +8,6 @@ from typing_extensions import override
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
)
from comfy_api_nodes.apis.recraft_api import ( from comfy_api_nodes.apis.recraft_api import (
RecraftColor, RecraftColor,
RecraftColorChain, RecraftColorChain,
@ -28,6 +25,7 @@ from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
bytesio_to_image_tensor, bytesio_to_image_tensor,
download_url_as_bytesio, download_url_as_bytesio,
resize_mask_to_image,
sync_op, sync_op,
tensor_to_bytesio, tensor_to_bytesio,
validate_string, validate_string,

View File

@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1) validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1) validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
validate_string(prompt, min_length=1) validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999) validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
reference_images = None reference_images = None
if reference_image is not None: if reference_image is not None:
validate_image_dimensions(reference_image, max_width=7999, max_height=7999) validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
download_urls = await upload_images_to_comfyapi( download_urls = await upload_images_to_comfyapi(
cls, cls,
reference_image, reference_image,

View File

@ -14,9 +14,9 @@ from comfy_api_nodes.util import (
poll_op, poll_op,
sync_op, sync_op,
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_aspect_ratio_closeness, validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions, validate_image_dimensions,
validate_images_aspect_ratio_closeness,
) )
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@ -114,7 +114,7 @@ async def execute_task(
cls, cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: r.state.value, status_extractor=lambda r: r.state,
estimated_duration=estimated_duration, estimated_duration=estimated_duration,
) )
@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if get_number_of_images(image) > 1: if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.") raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
prompt=prompt, prompt=prompt,
@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
if a > 7: if a > 7:
raise ValueError("Too many images, maximum allowed is 7.") raise ValueError("Too many images, maximum allowed is 7.")
for image in images: for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128) validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution: str, resolution: str,
movement_amplitude: str, movement_amplitude: str,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
prompt=prompt, prompt=prompt,

View File

@ -14,6 +14,7 @@ from .conversions import (
downscale_image_tensor, downscale_image_tensor,
image_tensor_pair_to_batch, image_tensor_pair_to_batch,
pil_to_bytesio, pil_to_bytesio,
resize_mask_to_image,
tensor_to_base64_string, tensor_to_base64_string,
tensor_to_bytesio, tensor_to_bytesio,
tensor_to_pil, tensor_to_pil,
@ -34,12 +35,12 @@ from .upload_helpers import (
) )
from .validation_utils import ( from .validation_utils import (
get_number_of_images, get_number_of_images,
validate_aspect_ratio_closeness, validate_aspect_ratio_string,
validate_audio_duration, validate_audio_duration,
validate_container_format_is_mp4, validate_container_format_is_mp4,
validate_image_aspect_ratio, validate_image_aspect_ratio,
validate_image_aspect_ratio_range,
validate_image_dimensions, validate_image_dimensions,
validate_images_aspect_ratio_closeness,
validate_string, validate_string,
validate_video_dimensions, validate_video_dimensions,
validate_video_duration, validate_video_duration,
@ -70,6 +71,7 @@ __all__ = [
"downscale_image_tensor", "downscale_image_tensor",
"image_tensor_pair_to_batch", "image_tensor_pair_to_batch",
"pil_to_bytesio", "pil_to_bytesio",
"resize_mask_to_image",
"tensor_to_base64_string", "tensor_to_base64_string",
"tensor_to_bytesio", "tensor_to_bytesio",
"tensor_to_pil", "tensor_to_pil",
@ -77,12 +79,12 @@ __all__ = [
"video_to_base64_string", "video_to_base64_string",
# Validation utilities # Validation utilities
"get_number_of_images", "get_number_of_images",
"validate_aspect_ratio_closeness", "validate_aspect_ratio_string",
"validate_audio_duration", "validate_audio_duration",
"validate_container_format_is_mp4", "validate_container_format_is_mp4",
"validate_image_aspect_ratio", "validate_image_aspect_ratio",
"validate_image_aspect_ratio_range",
"validate_image_dimensions", "validate_image_dimensions",
"validate_images_aspect_ratio_closeness",
"validate_string", "validate_string",
"validate_video_dimensions", "validate_video_dimensions",
"validate_video_duration", "validate_video_duration",

View File

@ -430,3 +430,24 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
wav = torch.cat(frames, dim=1) # [C, T] wav = torch.cat(frames, dim=1) # [C, T]
wav = _f32_pcm(wav) wav = _f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
_, height, width, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
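A brief usage sketch of the relocated resize_mask_to_image helper, now exported from comfy_api_nodes.util; the shapes are illustrative:

import torch
from comfy_api_nodes.util import resize_mask_to_image

image = torch.zeros(1, 512, 768, 3)   # [B, H, W, C] reference image
mask = torch.rand(1, 64, 64)          # [B, H, W] mask at a different resolution
resized = resize_mask_to_image(mask, image, allow_gradient=False)
print(resized.shape)                   # torch.Size([1, 512, 768]), binarized at 0.5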

View File

@ -37,63 +37,62 @@ def validate_image_dimensions(
def validate_image_aspect_ratio( def validate_image_aspect_ratio(
image: torch.Tensor, image: torch.Tensor,
min_aspect_ratio: Optional[float] = None, min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_aspect_ratio: Optional[float] = None, max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
):
width, height = get_image_dimensions(image)
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}")
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}")
def validate_image_aspect_ratio_range(
image: torch.Tensor,
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
*, *,
strict: bool = True, # True -> (min, max); False -> [min, max] strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float: ) -> float:
a1, b1 = min_ratio """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
a2, b2 = max_ratio
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
lo, hi = (a1 / b1), (a2 / b2)
if lo > hi:
lo, hi = hi, lo
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
w, h = get_image_dimensions(image) w, h = get_image_dimensions(image)
if w <= 0 or h <= 0: if w <= 0 or h <= 0:
raise ValueError(f"Invalid image dimensions: {w}x{h}") raise ValueError(f"Invalid image dimensions: {w}x{h}")
ar = w / h ar = w / h
ok = (lo < ar < hi) if strict else (lo <= ar <= hi) _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
if not ok:
op = "<" if strict else ""
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
return ar return ar
def validate_aspect_ratio_closeness( def validate_images_aspect_ratio_closeness(
start_img, first_image: torch.Tensor,
end_img, second_image: torch.Tensor,
min_rel: float, min_rel: float, # e.g. 0.8
max_rel: float, max_rel: float, # e.g. 1.25
*, *,
strict: bool = False, # True => exclusive, False => inclusive strict: bool = False, # True -> (min, max); False -> [min, max]
) -> None: ) -> float:
w1, h1 = get_image_dimensions(start_img) """
w2, h2 = get_image_dimensions(end_img) Validates that the two images' aspect ratios are 'close'.
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
Returns the computed closeness factor C.
"""
w1, h1 = get_image_dimensions(first_image)
w2, h2 = get_image_dimensions(second_image)
if min(w1, h1, w2, h2) <= 0: if min(w1, h1, w2, h2) <= 0:
raise ValueError("Invalid image dimensions") raise ValueError("Invalid image dimensions")
ar1 = w1 / h1 ar1 = w1 / h1
ar2 = w2 / h2 ar2 = w2 / h2
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
closeness = max(ar1, ar2) / min(ar1, ar2) closeness = max(ar1, ar2) / min(ar1, ar2)
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25 limit = max(max_rel, 1.0 / min_rel)
if (closeness >= limit) if strict else (closeness > limit): if (closeness >= limit) if strict else (closeness > limit):
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.") raise ValueError(
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
f"allowed range {min_rel}{max_rel} (limit {limit:.2g})."
)
return closeness
def validate_aspect_ratio_string(
aspect_ratio: str,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
ar = _parse_aspect_ratio_string(aspect_ratio)
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
return ar
def validate_video_dimensions( def validate_video_dimensions(
@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None:
container_format = video.get_container_format() container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}") raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
def _ratio_from_tuple(r: tuple[float, float]) -> float:
a, b = r
if a <= 0 or b <= 0:
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
return a / b
def _assert_ratio_bounds(
ar: float,
*,
min_ratio: Optional[tuple[float, float]] = None,
max_ratio: Optional[tuple[float, float]] = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
if lo is not None and hi is not None and lo > hi:
lo, hi = hi, lo # normalize order if caller swapped them
if lo is not None:
if (ar <= lo) if strict else (ar < lo):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
if hi is not None:
if (ar >= hi) if strict else (ar > hi):
op = "<" if strict else ""
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
def _parse_aspect_ratio_string(ar_str: str) -> float:
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
parts = ar_str.split(":")
if len(parts) != 2:
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
try:
a = int(parts[0].strip())
b = int(parts[1].strip())
except ValueError as exc:
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
if a <= 0 or b <= 0:
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
return a / b
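A short usage sketch for the validators above (the bound tuples and the image tensors are illustrative, not taken from any specific API node):

# 'X:Y' string parsing plus bound check between 1:4 and 4:1 (inclusive by default)
ratio = validate_aspect_ratio_string("16:9", min_ratio=(1, 4), max_ratio=(4, 1))   # -> ~1.78

# tensor-based check against the same kind of bounds (strict/open interval by default)
ar = validate_image_aspect_ratio(image, min_ratio=(1, 3), max_ratio=(3, 1))

# closeness between two frames: the allowed factor is max(1.25, 1 / 0.8) = 1.25
closeness = validate_images_aspect_ratio_closeness(first_image, second_image, min_rel=0.8, max_rel=1.25)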

View File

@ -1,4 +1,9 @@
import bisect
import gc
import itertools import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -188,6 +193,9 @@ class BasicCache:
self._clean_cache() self._clean_cache()
self._clean_subcaches() self._clean_subcaches()
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value): def _set_immediate(self, node_id, value):
assert self.initialized assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
@ -276,6 +284,9 @@ class NullCache:
def clean_unused(self): def clean_unused(self):
pass pass
def poll(self, **kwargs):
pass
def get(self, node_id): def get(self, node_id):
return None return None
@ -336,3 +347,75 @@ class LRUCache(BasicCache):
self._mark_used(child_id) self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self return self
#Iterating the cache for usage analysis can be expensive, so once eviction triggers, take
#a sizeable chunk out to give breathing space on high-node / low-RAM-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1
#Nominally in GB, but only approximately. It needs to be non-zero for the heuristic below,
#and as long as multi-GB models dwarf it, the OOM scoring remains a reasonable approximation.
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting results from older workflows, so stale data gets cleaned
#out in constantly changing setups.
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = []
for key, (outputs, _) in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
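A rough illustration of the scoring heuristic above (the numbers are made up; this is just arithmetic, not part of the cache API):

# An entry untouched for two generations, holding a 2 GiB CPU tensor:
stale_generations = 2
tensor_bytes = 2 * (1024 ** 3)
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE + tensor_bytes * 0.5        # tensors count at 50%
oom_score = (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** stale_generations) * ram_usage
# A same-sized entry from the current generation scores about 1.3**2 ~= 1.69x lower,
# so the stale entry sorts later in clean_list and is popped (evicted) first.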

View File

@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id) self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_output_cache(self, from_node_id, to_node_id): def get_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache: if not to_node_id in self.execution_cache:
return None return None
return self.execution_cache[to_node_id].get(from_node_id) value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value
def cache_update(self, node_id, value): def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners: if node_id in self.execution_cache_listeners:

View File

@ -21,6 +21,7 @@ from comfy_execution.caching import (
NullCache, NullCache,
HierarchicalCache, HierarchicalCache,
LRUCache, LRUCache,
RAMPressureCache,
) )
from comfy_execution.graph import ( from comfy_execution.graph import (
DynamicPrompt, DynamicPrompt,
@ -88,49 +89,56 @@ class IsChangedCache:
return self.is_changed[node_id] return self.is_changed[node_id]
class CacheEntry(NamedTuple):
ui: dict
outputs: list
class CacheType(Enum): class CacheType(Enum):
CLASSIC = 0 CLASSIC = 0
LRU = 1 LRU = 1
NONE = 2 NONE = 2
RAM_PRESSURE = 3
class CacheSet: class CacheSet:
def __init__(self, cache_type=None, cache_size=None): def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE: if cache_type == CacheType.NONE:
self.init_null_cache() self.init_null_cache()
logging.info("Disabling intermediate node cache.") logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU: elif cache_type == CacheType.LRU:
if cache_size is None: cache_size = cache_args.get("lru", 0)
cache_size = 0
self.init_lru_cache(cache_size) self.init_lru_cache(cache_size)
logging.info("Using LRU cache") logging.info("Using LRU cache")
else: else:
self.init_classic_cache() self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects] self.all = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP # Performs like the old cache -- dump data ASAP
def init_classic_cache(self): def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size): def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self): def init_null_cache(self):
self.outputs = NullCache() self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Hierarchical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache() self.objects = NullCache()
def recursive_debug_dump(self): def recursive_debug_dump(self):
result = { result = {
"outputs": self.outputs.recursive_debug_dump(), "outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
} }
return result return result
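A minimal sketch of constructing the updated CacheSet (the argument values are illustrative):

# RAM-pressure cache that starts evicting once available RAM nears the headroom target
caches = CacheSet(cache_type=CacheType.RAM_PRESSURE, cache_args={"ram": 4.0})

# LRU cache of up to 32 node results; there is no longer a separate ui cache,
# the UI payload travels inside each CacheEntry instead
caches = CacheSet(cache_type=CacheType.LRU, cache_args={"lru": 32})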
@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if execution_list is None: if execution_list is None:
mark_missing() mark_missing()
continue # This might be a lazily-evaluated input continue # This might be a lazily-evaluated input
cached_output = execution_list.get_output_cache(input_unique_id, unique_id) cached = execution_list.get_cache(input_unique_id, unique_id)
if cached_output is None: if cached is None or cached.outputs is None:
mark_missing() mark_missing()
continue continue
if output_index >= len(cached_output): if output_index >= len(cached.outputs):
mark_missing() mark_missing()
continue continue
obj = cached_output[output_index] obj = cached.outputs[output_index]
input_data_all[x] = obj input_data_all[x] = obj
elif input_category is not None: elif input_category is not None:
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
@ -393,7 +401,7 @@ def format_value(x):
else: else:
return str(x) return str(x)
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes): async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id) real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id)
@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs'] inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type'] class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None: cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None: if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {} cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id) get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
input_data_all = None input_data_all = None
@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result: for r in result:
if is_link(r): if is_link(r):
source_node, source_output = r[0], r[1] source_node, source_output = r[0], r[1]
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] node_cached = execution_list.get_cache(source_node, unique_id)
for o in node_output: for o in node_cached.outputs[source_output]:
resolved_output.append(o) resolved_output.append(o)
else: else:
@ -507,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
asyncio.create_task(await_completion()) asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
caches.ui.set(unique_id, { ui_outputs[unique_id] = {
"meta": { "meta": {
"node_id": unique_id, "node_id": unique_id,
"display_node": display_node_id, "display_node": display_node_id,
@ -515,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"real_node_id": real_node_id, "real_node_id": real_node_id,
}, },
"output": output_ui "output": output_ui
}) }
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph: if has_subgraph:
@ -554,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
pending_subgraph_results[unique_id] = cached_outputs pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data) cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, output_data) execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex: except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted") logging.info("Processing interrupted")
@ -600,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor: class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None): def __init__(self, server, cache_type=False, cache_args=None):
self.cache_size = cache_size self.cache_args = cache_args
self.cache_type = cache_type self.cache_type = cache_type
self.server = server self.server = server
self.reset() self.reset()
def reset(self): def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = [] self.status_messages = []
self.success = True self.success = True
@ -682,6 +694,7 @@ class PromptExecutor:
broadcast=False) broadcast=False)
pending_subgraph_results = {} pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set() executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids() current_outputs = self.caches.outputs.all_node_ids()
@ -695,7 +708,7 @@ class PromptExecutor:
break break
assert node_id is not None, "Node ID should not be None at this point" assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE: if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -704,18 +717,16 @@ class PromptExecutor:
execution_list.unstage_node_execution() execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS: else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution() execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {} ui_outputs = {}
meta_outputs = {} meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids() for node_id, ui_info in ui_node_outputs.items():
for node_id in all_node_ids: ui_outputs[node_id] = ui_info["output"]
ui_info = self.caches.ui.get(node_id) meta_outputs[node_id] = ui_info["meta"]
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = { self.history_result = {
"outputs": ui_outputs, "outputs": ui_outputs,
"meta": meta_outputs, "meta": meta_outputs,

View File

@ -198,10 +198,12 @@ def prompt_worker(q, server_instance):
cache_type = execution.CacheType.CLASSIC cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0: if args.cache_lru > 0:
cache_type = execution.CacheType.LRU cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none: elif args.cache_none:
cache_type = execution.CacheType.NONE cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0