diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 432571616..7a2b76322 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group() cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") +cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 74b9e48bc..674a214ca 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -276,6 +276,9 @@ class ModelPatcher: self.size = comfy.model_management.module_size(self.model) return self.size + def get_ram_usage(self): + return self.model_size() + def loaded_size(self): return self.model.model_loaded_weight_memory @@ -655,6 +658,7 @@ class ModelPatcher: mem_counter = 0 patch_counter = 0 lowvram_counter = 0 + lowvram_mem_counter = 0 loading = self._load_list() load_completely = [] @@ -675,6 +679,7 @@ class ModelPatcher: if mem_counter + module_mem >= lowvram_model_memory: lowvram_weight = True lowvram_counter += 1 + lowvram_mem_counter += module_mem if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed continue @@ -748,10 +753,10 @@ class ModelPatcher: self.pin_weight_to_device("{}.{}".format(n, param)) if lowvram_counter > 0: - logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) diff --git a/comfy/ops.py b/comfy/ops.py index 71ca7a2bd..18f6b804b 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -421,14 +421,18 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight) else: scale_input = scale_input.to(input.device) + quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) - quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index fb35a0d40..c822fe53c 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -357,9 +357,10 @@ class TensorCoreFP8Layout(QuantizedLayout): scale = torch.tensor(scale) scale = scale.to(device=tensor.device, dtype=torch.float32) - lp_amax = torch.finfo(dtype).max tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype) - torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) + # TODO: uncomment this if it's actually needed because the clamp has a small performance penality' + # lp_amax = torch.finfo(dtype).max + # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled) qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format) layout_params = { diff --git a/comfy/sd.py b/comfy/sd.py index de4eee96e..9e5ebbf15 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -143,6 +143,9 @@ class CLIP: n.apply_hooks_to_conds = self.apply_hooks_to_conds return n + def get_ram_usage(self): + return self.patcher.get_ram_usage() + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) @@ -293,6 +296,7 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False + self.size = None self.downscale_index_formula = None self.upscale_index_formula = None @@ -595,6 +599,16 @@ class VAE: self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + self.model_size() + + def model_size(self): + if self.size is not None: + return self.size + self.size = comfy.model_management.module_size(self.first_stage_model) + return self.size + + def get_ram_usage(self): + return self.model_size() def throw_exception_if_invalid(self): if self.first_stage_model is None: diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 6a72b9d1d..ecd604ff8 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,15 +1,12 @@ from __future__ import annotations import aiohttp import mimetypes -from typing import Optional, Union -from comfy.utils import common_upscale +from typing import Union from server import PromptServer -from comfy.cli_args import args import numpy as np from PIL import Image import torch -import math import base64 from io import BytesIO @@ -60,85 +57,6 @@ async def validate_and_cast_response( return torch.stack(image_tensors, dim=0) -def validate_aspect_ratio( - aspect_ratio: str, - minimum_ratio: float, - maximum_ratio: float, - minimum_ratio_str: str, - maximum_ratio_str: str, -) -> float: - """Validates and casts an aspect ratio string to a float. - - Args: - aspect_ratio: The aspect ratio string to validate. - minimum_ratio: The minimum aspect ratio. - maximum_ratio: The maximum aspect ratio. - minimum_ratio_str: The minimum aspect ratio string. - maximum_ratio_str: The maximum aspect ratio string. - - Returns: - The validated and cast aspect ratio. - - Raises: - Exception: If the aspect ratio is not valid. - """ - # get ratio values - numbers = aspect_ratio.split(":") - if len(numbers) != 2: - raise TypeError( - f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}." - ) - try: - numerator = int(numbers[0]) - denominator = int(numbers[1]) - except ValueError as exc: - raise TypeError( - f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}." - ) from exc - calculated_ratio = numerator / denominator - # if not close to minimum and maximum, check bounds - if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose( - calculated_ratio, maximum_ratio - ): - if calculated_ratio < minimum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - if calculated_ratio > maximum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - return aspect_ratio - - -async def download_url_to_bytesio( - url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None -) -> BytesIO: - """Downloads content from a URL using requests and returns it as BytesIO. - - Args: - url: The URL to download. - timeout: Request timeout in seconds. Defaults to None (no timeout). - - Returns: - BytesIO object containing the downloaded content. - """ - headers = {} - if url.startswith("/proxy/"): - url = str(args.comfy_api_base).rstrip("/") + url - auth_token = auth_kwargs.get("auth_token") - comfy_api_key = auth_kwargs.get("comfy_api_key") - if auth_token: - headers["Authorization"] = f"Bearer {auth_token}" - elif comfy_api_key: - headers["X-API-KEY"] = comfy_api_key - timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None - async with aiohttp.ClientSession(timeout=timeout_cfg) as session: - async with session.get(url, headers=headers) as resp: - resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) - return BytesIO(await resp.read()) - - def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: @@ -153,28 +71,3 @@ def text_filepath_to_data_uri(filepath: str) -> str: if mime_type is None: mime_type = "application/octet-stream" return f"data:{mime_type};base64,{base64_string}" - - -def resize_mask_to_image( - mask: torch.Tensor, - image: torch.Tensor, - upscale_method="nearest-exact", - crop="disabled", - allow_gradient=True, - add_channel_dim=False, -): - """ - Resize mask to be the same dimensions as an image, while maintaining proper format for API calls. - """ - _, H, W, _ = image.shape - mask = mask.unsqueeze(-1) - mask = mask.movedim(-1, 1) - mask = common_upscale( - mask, width=W, height=H, upscale_method=upscale_method, crop=crop - ) - mask = mask.movedim(1, -1) - if not add_channel_dim: - mask = mask.squeeze(-1) - if not allow_gradient: - mask = (mask > 0.5).float() - return mask diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index baa74fd52..1740fb377 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -5,10 +5,6 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apinode_utils import ( - resize_mask_to_image, - validate_aspect_ratio, -) from comfy_api_nodes.apis.bfl_api import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, @@ -23,8 +19,10 @@ from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, poll_op, + resize_mask_to_image, sync_op, tensor_to_base64_string, + validate_aspect_ratio_string, validate_string, ) @@ -43,11 +41,6 @@ class FluxProUltraImageNode(IO.ComfyNode): Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. """ - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( @@ -112,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode): @classmethod def validate_inputs(cls, aspect_ratio: str): - try: - validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) - except Exception as e: - return str(e) + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) return True @classmethod @@ -145,13 +129,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt=prompt, prompt_upsampling=prompt_upsampling, seed=seed, - aspect_ratio=validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ), + aspect_ratio=aspect_ratio, raw=raw, image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), @@ -180,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode): Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. """ - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( @@ -261,13 +234,7 @@ class FluxKontextProImageNode(IO.ComfyNode): seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: - aspect_ratio = validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) if input_image is None: validate_string(prompt, strip_whitespace=False) initial_response = await sync_op( diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 534af380d..caced471e 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -17,7 +17,7 @@ from comfy_api_nodes.util import ( poll_op, sync_op, upload_images_to_comfyapi, - validate_image_aspect_ratio_range, + validate_image_aspect_ratio, validate_image_dimensions, validate_string, ) @@ -403,7 +403,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) + validate_image_aspect_ratio(image, (1, 3), (3, 1)) source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] payload = Image2ImageTaskCreationRequest( model=model, @@ -565,7 +565,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): reference_images_urls = [] if n_input_images: for i in image: - validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) + validate_image_aspect_ratio(i, (1, 3), (3, 1)) reference_images_urls = await upload_images_to_comfyapi( cls, image, @@ -798,7 +798,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] prompt = ( @@ -923,7 +923,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 download_urls = await upload_images_to_comfyapi( cls, @@ -1045,7 +1045,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) for image in images: validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") prompt = ( diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index d8fd3378b..48f94e612 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -1,6 +1,6 @@ from io import BytesIO from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch @@ -11,19 +11,13 @@ from comfy_api_nodes.apis import ( IdeogramV3Request, IdeogramV3EditRequest, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) - -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, + bytesio_to_image_tensor, + download_url_as_bytesio, resize_mask_to_image, + sync_op, ) -from comfy_api_nodes.util import bytesio_to_image_tensor -from server import PromptServer V1_V1_RES_MAP = { "Auto":"AUTO", @@ -220,7 +214,7 @@ async def download_and_process_images(image_urls): for image_url in image_urls: # Using functions from apinode_utils.py to handle downloading and processing - image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO + image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode image_tensors.append(img_tensor) @@ -233,19 +227,6 @@ async def download_and_process_images(image_urls): return stacked_tensors -def display_image_urls_on_node(image_urls, node_id): - if node_id and image_urls: - if len(image_urls) == 1: - PromptServer.instance.send_progress_text( - f"Generated Image URL:\n{image_urls[0]}", node_id - ) - else: - urls_text = "Generated Image URLs:\n" + "\n".join( - f"{i+1}. {url}" for i, url in enumerate(image_urls) - ) - PromptServer.instance.send_progress_text(urls_text, node_id) - - class IdeogramV1(IO.ComfyNode): @classmethod @@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, num_images=num_images, seed=seed, aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), negative_prompt=negative_prompt if negative_prompt else None, ) ), - auth_kwargs=auth, + max_retries=1, ) - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode): else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, @@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode): seed=seed, aspect_ratio=final_aspect_ratio, resolution=final_resolution, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), style_type=style_type if style_type != "NONE" else None, negative_prompt=negative_prompt if negative_prompt else None, color_palette=color_palette if color_palette else None, ) ), - auth_kwargs=auth, + max_retries=1, ) - - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode): character_image=None, character_mask=None, ): - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } if rendering_speed == "BALANCED": # for backward compatibility rendering_speed = "DEFAULT" @@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode): # Check if both image and mask are provided for editing mode if image is not None and mask is not None: - # Edit mode - path = "/proxy/ideogram/ideogram-v3/edit" - # Process image and mask input_tensor = image.squeeze().cpu() # Resize mask to match image dimension @@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode): if character_mask_binary: files["character_mask_binary"] = character_mask_binary - # Execute the operation for edit mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3EditRequest, - response_model=IdeogramGenerateResponse, - ), - request=edit_request, + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"), + response_model=IdeogramGenerateResponse, + data=edit_request, files=files, content_type="multipart/form-data", - auth_kwargs=auth, + max_retries=1, ) elif image is not None or mask is not None: # If only one of image or mask is provided, raise an error raise Exception("Ideogram V3 image editing requires both an image AND a mask") else: - # Generation mode - path = "/proxy/ideogram/ideogram-v3/generate" - # Create generation request gen_request = IdeogramV3Request( prompt=prompt, @@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode): if files: gen_request.style_type = "AUTO" - # Execute the operation for generation mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3Request, - response_model=IdeogramGenerateResponse, - ), - request=gen_request, + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=gen_request, files=files if files else None, content_type="multipart/form-data", - auth_kwargs=auth, + max_retries=1, ) - # Execute the operation and process response - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension): IdeogramV3, ] + async def comfy_entrypoint() -> IdeogramExtension: return IdeogramExtension() diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index eea65c9ac..7b23e9cf9 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -282,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None: See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo """ validate_image_dimensions(image, min_width=300, min_height=300) - validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) + validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1)) def get_video_from_response(response) -> KlingVideoResult: diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index c467e840c..b4568fc85 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -225,7 +225,7 @@ class OpenAIDalle2(ComfyNodeABC): ), files=( { - "image": img_binary, + "image": ("image.png", img_binary, "image/png"), } if img_binary else None diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index b2b841be8..6e1686af0 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,6 @@ -from inspect import cleandoc -from typing import Optional +import torch from typing_extensions import override -from io import BytesIO +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -17,53 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import ( PixverseIO, pixverse_templates, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, + download_url_to_video_output, + poll_op, + sync_op, + tensor_to_bytesio, + validate_string, ) -from comfy_api_nodes.util import validate_string, tensor_to_bytesio -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO - -import torch -import aiohttp - AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_I2V = 30 AVERAGE_DURATION_T2T = 52 -def get_video_url_from_response( - response: PixverseGenerationStatusResponse, -) -> Optional[str]: - if response.Resp is None or response.Resp.url is None: - return None - return str(response.Resp.url) - - -async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): - # first, upload image to Pixverse and get image id to use in actual generation call - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/image/upload", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=PixverseImageUploadResponse, - ), - request=EmptyRequest(), +async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor): + response_upload = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"), + response_model=PixverseImageUploadResponse, files={"image": tensor_to_bytesio(image)}, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response_upload: PixverseImageUploadResponse = await operation.execute() - if response_upload.Resp is None: raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") - return response_upload.Resp.img_id @@ -93,17 +69,13 @@ class PixverseTemplateNode(IO.ComfyNode): class PixverseTextToVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseTextToVideoNode", display_name="PixVerse Text to Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.String.Input( "prompt", @@ -170,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): negative_prompt: str = None, pixverse_template: int = None, ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False) + validate_string(prompt, strip_whitespace=False, min_length=1) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration if quality == PixverseQuality.res_1080p: @@ -179,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/text/generate", - method=HttpMethod.POST, - request_model=PixverseTextVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTextVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTextVideoRequest( prompt=prompt, aspect_ratio=aspect_ratio, quality=quality, @@ -200,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode): template_id=pixverse_template, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() - if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -221,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixverseImageToVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseImageToVideoNode", display_name="PixVerse Image to Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -309,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): pixverse_template: int = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - img_id = await upload_image_to_pixverse(image, auth_kwargs=auth) + img_id = await upload_image_to_pixverse(cls, image) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -323,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/img/generate", - method=HttpMethod.POST, - request_model=PixverseImageVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseImageVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseImageVideoRequest( img_id=img_id, prompt=prompt, quality=quality, @@ -340,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode): template_id=pixverse_template, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -361,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixverseTransitionVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseTransitionVideoNode", display_name="PixVerse Transition Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("first_frame"), IO.Image.Input("last_frame"), @@ -445,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode): negative_prompt: str = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth) - last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth) + first_frame_id = await upload_image_to_pixverse(cls, first_frame) + last_frame_id = await upload_image_to_pixverse(cls, last_frame) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -460,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/transition/generate", - method=HttpMethod.POST, - request_model=PixverseTransitionVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTransitionVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTransitionVideoRequest( first_frame_img=first_frame_id, last_frame_img=last_frame_id, prompt=prompt, @@ -477,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode): negative_prompt=negative_prompt if negative_prompt else None, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -498,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixVerseExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index dee186cd6..e3440b946 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -8,9 +8,6 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apinode_utils import ( - resize_mask_to_image, -) from comfy_api_nodes.apis.recraft_api import ( RecraftColor, RecraftColorChain, @@ -28,6 +25,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, bytesio_to_image_tensor, download_url_as_bytesio, + resize_mask_to_image, sync_op, tensor_to_bytesio, validate_string, diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 0543d1d0e..2fdafbbfe 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -200,7 +200,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): ) -> IO.NodeOutput: validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( cls, @@ -290,7 +290,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): ) -> IO.NodeOutput: validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( cls, @@ -390,8 +390,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(end_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) + validate_image_aspect_ratio(end_frame, (1, 2), (2, 1)) stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( @@ -475,7 +475,7 @@ class RunwayTextToImageNode(IO.ComfyNode): reference_images = None if reference_image is not None: validate_image_dimensions(reference_image, max_width=7999, max_height=7999) - validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(reference_image, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( cls, reference_image, diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 0e0572f8c..7a679f0d9 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -14,9 +14,9 @@ from comfy_api_nodes.util import ( poll_op, sync_op, upload_images_to_comfyapi, - validate_aspect_ratio_closeness, - validate_image_aspect_ratio_range, + validate_image_aspect_ratio, validate_image_dimensions, + validate_images_aspect_ratio_closeness, ) VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -114,7 +114,7 @@ async def execute_task( cls, ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), response_model=TaskStatusResponse, - status_extractor=lambda r: r.state.value, + status_extractor=lambda r: r.state, estimated_duration=estimated_duration, ) @@ -307,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: if get_number_of_images(image) > 1: raise ValueError("Only one input image is allowed.") - validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_aspect_ratio(image, (1, 4), (4, 1)) payload = TaskCreationRequest( model_name=model, prompt=prompt, @@ -423,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): if a > 7: raise ValueError("Too many images, maximum allowed is 7.") for image in images: - validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_dimensions(image, min_width=128, min_height=128) payload = TaskCreationRequest( model_name=model, @@ -533,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): resolution: str, movement_amplitude: str, ) -> IO.NodeOutput: - validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) payload = TaskCreationRequest( model_name=model, prompt=prompt, diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 0cca2b59b..bbc71363a 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -14,6 +14,7 @@ from .conversions import ( downscale_image_tensor, image_tensor_pair_to_batch, pil_to_bytesio, + resize_mask_to_image, tensor_to_base64_string, tensor_to_bytesio, tensor_to_pil, @@ -34,12 +35,12 @@ from .upload_helpers import ( ) from .validation_utils import ( get_number_of_images, - validate_aspect_ratio_closeness, + validate_aspect_ratio_string, validate_audio_duration, validate_container_format_is_mp4, validate_image_aspect_ratio, - validate_image_aspect_ratio_range, validate_image_dimensions, + validate_images_aspect_ratio_closeness, validate_string, validate_video_dimensions, validate_video_duration, @@ -70,6 +71,7 @@ __all__ = [ "downscale_image_tensor", "image_tensor_pair_to_batch", "pil_to_bytesio", + "resize_mask_to_image", "tensor_to_base64_string", "tensor_to_bytesio", "tensor_to_pil", @@ -77,12 +79,12 @@ __all__ = [ "video_to_base64_string", # Validation utilities "get_number_of_images", - "validate_aspect_ratio_closeness", + "validate_aspect_ratio_string", "validate_audio_duration", "validate_container_format_is_mp4", "validate_image_aspect_ratio", - "validate_image_aspect_ratio_range", "validate_image_dimensions", + "validate_images_aspect_ratio_closeness", "validate_string", "validate_video_dimensions", "validate_video_duration", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 9f4c90c5c..b59c2bd84 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -430,3 +430,24 @@ def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict: wav = torch.cat(frames, dim=1) # [C, T] wav = _f32_pcm(wav) return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} + + +def resize_mask_to_image( + mask: torch.Tensor, + image: torch.Tensor, + upscale_method="nearest-exact", + crop="disabled", + allow_gradient=True, + add_channel_dim=False, +): + """Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.""" + _, height, width, _ = image.shape + mask = mask.unsqueeze(-1) + mask = mask.movedim(-1, 1) + mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop) + mask = mask.movedim(1, -1) + if not add_channel_dim: + mask = mask.squeeze(-1) + if not allow_gradient: + mask = (mask > 0.5).float() + return mask diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 22da05bc1..ec7006aed 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -37,63 +37,62 @@ def validate_image_dimensions( def validate_image_aspect_ratio( image: torch.Tensor, - min_aspect_ratio: Optional[float] = None, - max_aspect_ratio: Optional[float] = None, -): - width, height = get_image_dimensions(image) - aspect_ratio = width / height - - if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: - raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}") - if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: - raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}") - - -def validate_image_aspect_ratio_range( - image: torch.Tensor, - min_ratio: tuple[float, float], # e.g. (1, 4) - max_ratio: tuple[float, float], # e.g. (4, 1) + min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) + max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) *, strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: - a1, b1 = min_ratio - a2, b2 = max_ratio - if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0: - raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).") - lo, hi = (a1 / b1), (a2 / b2) - if lo > hi: - lo, hi = hi, lo - a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text + """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked.""" w, h = get_image_dimensions(image) if w <= 0 or h <= 0: raise ValueError(f"Invalid image dimensions: {w}x{h}") ar = w / h - ok = (lo < ar < hi) if strict else (lo <= ar <= hi) - if not ok: - op = "<" if strict else "≤" - raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}") + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) return ar -def validate_aspect_ratio_closeness( - start_img, - end_img, - min_rel: float, - max_rel: float, +def validate_images_aspect_ratio_closeness( + first_image: torch.Tensor, + second_image: torch.Tensor, + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 *, - strict: bool = False, # True => exclusive, False => inclusive -) -> None: - w1, h1 = get_image_dimensions(start_img) - w2, h2 = get_image_dimensions(end_img) + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """ + Validates that the two images' aspect ratios are 'close'. + The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1). + We require C <= limit, where limit = max(max_rel, 1.0 / min_rel). + + Returns the computed closeness factor C. + """ + w1, h1 = get_image_dimensions(first_image) + w2, h2 = get_image_dimensions(second_image) if min(w1, h1, w2, h2) <= 0: raise ValueError("Invalid image dimensions") ar1 = w1 / h1 ar2 = w2 / h2 - # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1) closeness = max(ar1, ar2) / min(ar1, ar2) - limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25 + limit = max(max_rel, 1.0 / min_rel) if (closeness >= limit) if strict else (closeness > limit): - raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.") + raise ValueError( + f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, " + f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})." + ) + return closeness + + +def validate_aspect_ratio_string( + aspect_ratio: str, + min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) + max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) + *, + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio.""" + ar = _parse_aspect_ratio_string(aspect_ratio) + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) + return ar def validate_video_dimensions( @@ -183,3 +182,49 @@ def validate_container_format_is_mp4(video: VideoInput) -> None: container_format = video.get_container_format() if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: raise ValueError(f"Only MP4 container format supported. Got: {container_format}") + + +def _ratio_from_tuple(r: tuple[float, float]) -> float: + a, b = r + if a <= 0 or b <= 0: + raise ValueError(f"Ratios must be positive, got {a}:{b}.") + return a / b + + +def _assert_ratio_bounds( + ar: float, + *, + min_ratio: Optional[tuple[float, float]] = None, + max_ratio: Optional[tuple[float, float]] = None, + strict: bool = True, +) -> None: + """Validate a numeric aspect ratio against optional min/max ratio bounds.""" + lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None + hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None + + if lo is not None and hi is not None and lo > hi: + lo, hi = hi, lo # normalize order if caller swapped them + + if lo is not None: + if (ar <= lo) if strict else (ar < lo): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.") + if hi is not None: + if (ar >= hi) if strict else (ar > hi): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.") + + +def _parse_aspect_ratio_string(ar_str: str) -> float: + """Parse 'X:Y' with integer parts into a positive float ratio X/Y.""" + parts = ar_str.split(":") + if len(parts) != 2: + raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.") + try: + a = int(parts[0].strip()) + b = int(parts[1].strip()) + except ValueError as exc: + raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc + if a <= 0 or b <= 0: + raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.") + return a / b diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 566bc3f9c..b498f43e7 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,4 +1,9 @@ +import bisect +import gc import itertools +import psutil +import time +import torch from typing import Sequence, Mapping, Dict from comfy_execution.graph import DynamicPrompt from abc import ABC, abstractmethod @@ -188,6 +193,9 @@ class BasicCache: self._clean_cache() self._clean_subcaches() + def poll(self, **kwargs): + pass + def _set_immediate(self, node_id, value): assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) @@ -276,6 +284,9 @@ class NullCache: def clean_unused(self): pass + def poll(self, **kwargs): + pass + def get(self, node_id): return None @@ -336,3 +347,75 @@ class LRUCache(BasicCache): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self + + +#Iterating the cache for usage analysis might be expensive, so if we trigger make sure +#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. + +RAM_CACHE_HYSTERESIS = 1.1 + +#This is kinda in GB but not really. It needs to be non-zero for the below heuristic +#and as long as Multi GB models dwarf this it will approximate OOM scoring OK + +RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 + +#Exponential bias towards evicting older workflows so garbage will be taken out +#in constantly changing setups. + +RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 + +class RAMPressureCache(LRUCache): + + def __init__(self, key_class): + super().__init__(key_class, 0) + self.timestamps = {} + + def clean_unused(self): + self._clean_subcaches() + + def set(self, node_id, value): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + super().set(node_id, value) + + def get(self, node_id): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + return super().get(node_id) + + def poll(self, ram_headroom): + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > ram_headroom: + return + gc.collect() + if _ram_gb() > ram_headroom: + return + + clean_list = [] + + for key, (outputs, _), in self.cache.items(): + oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) + + ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE + def scan_list_for_ram_usage(outputs): + nonlocal ram_usage + for output in outputs: + if isinstance(output, list): + scan_list_for_ram_usage(output) + elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': + #score Tensors at a 50% discount for RAM usage as they are likely to + #be high value intermediates + ram_usage += (output.numel() * output.element_size()) * 0.5 + elif hasattr(output, "get_ram_usage"): + ram_usage += output.get_ram_usage() + scan_list_for_ram_usage(outputs) + + oom_score *= ram_usage + #In the case where we have no information on the node ram usage at all, + #break OOM score ties on the last touch timestamp (pure LRU) + bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + + while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + _, _, key = clean_list.pop() + del self.cache[key] + gc.collect() diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 341c9735d..0d811e354 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) - def get_output_cache(self, from_node_id, to_node_id): + def get_cache(self, from_node_id, to_node_id): if not to_node_id in self.execution_cache: return None - return self.execution_cache[to_node_id].get(from_node_id) + value = self.execution_cache[to_node_id].get(from_node_id) + if value is None: + return None + #Write back to the main cache on touch. + self.output_cache.set(from_node_id, value) + return value def cache_update(self, node_id, value): if node_id in self.execution_cache_listeners: diff --git a/execution.py b/execution.py index 20e106213..17c77beab 100644 --- a/execution.py +++ b/execution.py @@ -21,6 +21,7 @@ from comfy_execution.caching import ( NullCache, HierarchicalCache, LRUCache, + RAMPressureCache, ) from comfy_execution.graph import ( DynamicPrompt, @@ -88,49 +89,56 @@ class IsChangedCache: return self.is_changed[node_id] +class CacheEntry(NamedTuple): + ui: dict + outputs: list + + class CacheType(Enum): CLASSIC = 0 LRU = 1 NONE = 2 + RAM_PRESSURE = 3 class CacheSet: - def __init__(self, cache_type=None, cache_size=None): + def __init__(self, cache_type=None, cache_args={}): if cache_type == CacheType.NONE: self.init_null_cache() logging.info("Disabling intermediate node cache.") + elif cache_type == CacheType.RAM_PRESSURE: + cache_ram = cache_args.get("ram", 16.0) + self.init_ram_cache(cache_ram) + logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: - if cache_size is None: - cache_size = 0 + cache_size = cache_args.get("lru", 0) self.init_lru_cache(cache_size) logging.info("Using LRU cache") else: self.init_classic_cache() - self.all = [self.outputs, self.ui, self.objects] + self.all = [self.outputs, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): self.outputs = HierarchicalCache(CacheKeySetInputSignature) - self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + def init_ram_cache(self, min_headroom): + self.outputs = RAMPressureCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): self.outputs = NullCache() - #The UI cache is expected to be iterable at the end of each workflow - #so it must cache at least a full workflow. Use Heirachical - self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = NullCache() def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), - "ui": self.ui.recursive_debug_dump(), } return result @@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = execution_list.get_output_cache(input_unique_id, unique_id) - if cached_output is None: + cached = execution_list.get_cache(input_unique_id, unique_id) + if cached is None or cached.outputs is None: mark_missing() continue - if output_index >= len(cached_output): + if output_index >= len(cached.outputs): mark_missing() continue - obj = cached_output[output_index] + obj = cached.outputs[output_index] input_data_all[x] = obj elif input_category is not None: input_data_all[x] = [input_data] @@ -393,7 +401,7 @@ def format_value(x): else: return str(x) -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes): +async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if caches.outputs.get(unique_id) is not None: + cached = caches.outputs.get(unique_id) + if cached is not None: if server.client_id is not None: - cached_output = caches.ui.get(unique_id) or {} - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) + cached_ui = cached.ui or {} + server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) + if cached.ui is not None: + ui_outputs[unique_id] = cached.ui get_progress_state().finish_progress(unique_id) - execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) + execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] - for o in node_output: + node_cached = execution_list.get_cache(source_node, unique_id) + for o in node_cached.outputs[source_output]: resolved_output.append(o) else: @@ -507,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: - caches.ui.set(unique_id, { + ui_outputs[unique_id] = { "meta": { "node_id": unique_id, "display_node": display_node_id, @@ -515,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, "real_node_id": real_node_id, }, "output": output_ui - }) + } if server.client_id is not None: server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) if has_subgraph: @@ -554,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) - caches.outputs.set(unique_id, output_data) - execution_list.cache_update(unique_id, output_data) + cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) + execution_list.cache_update(unique_id, cache_entry) + caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") @@ -600,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server, cache_type=False, cache_size=None): - self.cache_size = cache_size + def __init__(self, server, cache_type=False, cache_args=None): + self.cache_args = cache_args self.cache_type = cache_type self.server = server self.reset() def reset(self): - self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.status_messages = [] self.success = True @@ -682,6 +694,7 @@ class PromptExecutor: broadcast=False) pending_subgraph_results = {} pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() @@ -695,7 +708,7 @@ class PromptExecutor: break assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -704,18 +717,16 @@ class PromptExecutor: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) else: # Only execute when the while-loop ends without break self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} meta_outputs = {} - all_node_ids = self.caches.ui.all_node_ids() - for node_id in all_node_ids: - ui_info = self.caches.ui.get(node_id) - if ui_info is not None: - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] self.history_result = { "outputs": ui_outputs, "meta": meta_outputs, diff --git a/main.py b/main.py index 953b48a37..240c1255f 100644 --- a/main.py +++ b/main.py @@ -198,10 +198,12 @@ def prompt_worker(q, server_instance): cache_type = execution.CacheType.CLASSIC if args.cache_lru > 0: cache_type = execution.CacheType.LRU + elif args.cache_ram > 0: + cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0