From 42974a448c39af50c5f72d8c70267f9fe2971cd2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 7 Aug 2025 14:54:09 -0700 Subject: [PATCH 01/71] _ui.py import torchaudio safety check (#9234) * Added safety around torchaudio import in _ui.py * Trusted cursor too much, fixed torchaudio bool --- comfy_api/latest/_ui.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 6b8a39d58..61597038f 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -9,7 +9,11 @@ from typing import Type import av import numpy as np import torch -import torchaudio +try: + import torchaudio + TORCH_AUDIO_AVAILABLE = True +except ImportError: + TORCH_AUDIO_AVAILABLE = False from PIL import Image as PILImage from PIL.PngImagePlugin import PngInfo @@ -302,6 +306,8 @@ class AudioSaveHelper: # Resample if necessary if sample_rate != audio["sample_rate"]: + if not TORCH_AUDIO_AVAILABLE: + raise Exception("torchaudio is not available; cannot resample audio.") waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) # Create output with specified format From bf2a1b5b1ef72b240454f3ac44f5209af45efe00 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 8 Aug 2025 06:37:50 +0300 Subject: [PATCH 02/71] async API nodes (#9129) * converted API nodes to async * converted BFL API nodes to async * fixed client bug; converted gemini, ideogram, minimax * fixed client bug; converted openai nodes * fixed client bug; converted moonvalley, pika nodes * fixed client bug; converted kling, luma nodes * converted pixverse, rodin nodes * converted tripo, veo2 * converted recraft nodes * add lost log_request_response call --- comfy_api_nodes/apinode_utils.py | 152 ++--- comfy_api_nodes/apis/client.py | 901 +++++++++++----------------- comfy_api_nodes/nodes_bfl.py | 134 +++-- comfy_api_nodes/nodes_gemini.py | 4 +- comfy_api_nodes/nodes_ideogram.py | 22 +- comfy_api_nodes/nodes_kling.py | 130 ++-- comfy_api_nodes/nodes_luma.py | 70 ++- comfy_api_nodes/nodes_minimax.py | 14 +- comfy_api_nodes/nodes_moonvalley.py | 38 +- comfy_api_nodes/nodes_openai.py | 28 +- comfy_api_nodes/nodes_pika.py | 63 +- comfy_api_nodes/nodes_pixverse.py | 46 +- comfy_api_nodes/nodes_recraft.py | 44 +- comfy_api_nodes/nodes_rodin.py | 147 ++--- comfy_api_nodes/nodes_runway.py | 54 +- comfy_api_nodes/nodes_stability.py | 23 +- comfy_api_nodes/nodes_tripo.py | 69 ++- comfy_api_nodes/nodes_veo2.py | 15 +- 18 files changed, 878 insertions(+), 1076 deletions(-) diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 788e2803f..f953f86df 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,4 +1,5 @@ from __future__ import annotations +import aiohttp import io import logging import mimetypes @@ -21,7 +22,6 @@ from server import PromptServer import numpy as np from PIL import Image -import requests import torch import math import base64 @@ -30,7 +30,7 @@ from io import BytesIO import av -def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: +async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output. Args: @@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr Returns: A Comfy node `VIDEO` output. """ - video_io = download_url_to_bytesio(video_url, timeout) + video_io = await download_url_to_bytesio(video_url, timeout) if video_io is None: error_msg = f"Failed to download video from {video_url}" logging.error(error_msg) @@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: return s -def validate_and_cast_response( +async def validate_and_cast_response( response, timeout: int = None, node_id: Union[str, None] = None ) -> torch.Tensor: """Validates and casts a response to a torch.Tensor. @@ -86,35 +86,24 @@ def validate_and_cast_response( image_tensors: list[torch.Tensor] = [] # Process each image in the data array - for image_data in data: - image_url = image_data.url - b64_data = image_data.b64_json + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + for img_data in data: + img_bytes: bytes + if img_data.b64_json: + img_bytes = base64.b64decode(img_data.b64_json) + elif img_data.url: + if node_id: + PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id) + async with session.get(img_data.url) as resp: + if resp.status != 200: + raise ValueError("Failed to download generated image") + img_bytes = await resp.read() + else: + raise ValueError("Invalid image payload – neither URL nor base64 data present.") - if not image_url and not b64_data: - raise ValueError("No image was generated in the response") - - if b64_data: - img_data = base64.b64decode(b64_data) - img = Image.open(io.BytesIO(img_data)) - - elif image_url: - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {image_url}", node_id - ) - img_response = requests.get(image_url, timeout=timeout) - if img_response.status_code != 200: - raise ValueError("Failed to download the image") - img = Image.open(io.BytesIO(img_response.content)) - - img = img.convert("RGBA") - - # Convert to numpy array, normalize to float32 between 0 and 1 - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array) - - # Add to list of tensors - image_tensors.append(img_tensor) + pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) return torch.stack(image_tensors, dim=0) @@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: +async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: """Downloads content from a URL using requests and returns it as BytesIO. Args: @@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: Returns: BytesIO object containing the downloaded content. """ - response = requests.get(url, stream=True, timeout=timeout) - response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) - return BytesIO(response.content) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.get(url) as resp: + resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) + return BytesIO(await resp.read()) def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: @@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch return torch.from_numpy(image_array).unsqueeze(0) -def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: +async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = download_url_to_bytesio(url, timeout) + image_bytesio = await download_url_to_bytesio(url, timeout) return bytesio_to_image_tensor(image_bytesio) -def process_image_response(response: requests.Response) -> torch.Tensor: +def process_image_response(response_content: bytes | str) -> torch.Tensor: """Uses content from a Response object and converts it to a torch.Tensor""" - return bytesio_to_image_tensor(BytesIO(response.content)) + return bytesio_to_image_tensor(BytesIO(response_content)) def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: @@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str: return f"data:{mime_type};base64,{base64_string}" -def upload_file_to_comfyapi( +async def upload_file_to_comfyapi( file_bytes_io: BytesIO, filename: str, - upload_mime_type: str, + upload_mime_type: Optional[str], auth_kwargs: Optional[dict[str, str]] = None, ) -> str: """ @@ -354,7 +345,10 @@ def upload_file_to_comfyapi( Returns: The download URL for the uploaded file. """ - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) operation = SynchronousOperation( endpoint=ApiEndpoint( path="/customers/storage", @@ -366,12 +360,8 @@ def upload_file_to_comfyapi( auth_kwargs=auth_kwargs, ) - response: UploadResponse = operation.execute() - upload_response = ApiClient.upload_file( - response.upload_url, file_bytes_io, content_type=upload_mime_type - ) - upload_response.raise_for_status() - + response: UploadResponse = await operation.execute() + await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type) return response.download_url @@ -399,7 +389,7 @@ def video_to_base64_string( return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") -def upload_video_to_comfyapi( +async def upload_video_to_comfyapi( video: VideoInput, auth_kwargs: Optional[dict[str, str]] = None, container: VideoContainer = VideoContainer.MP4, @@ -439,9 +429,7 @@ def upload_video_to_comfyapi( video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) - return upload_file_to_comfyapi( - video_bytes_io, filename, upload_mime_type, auth_kwargs - ) + return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs) def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: @@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio( return audio_bytes_io -def upload_audio_to_comfyapi( +async def upload_audio_to_comfyapi( audio: AudioInput, auth_kwargs: Optional[dict[str, str]] = None, container_format: str = "mp4", @@ -527,7 +515,7 @@ def upload_audio_to_comfyapi( audio_data_np, sample_rate, container_format, codec_name ) - return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) + return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) def audio_to_base64_string( @@ -544,7 +532,7 @@ def audio_to_base64_string( return base64.b64encode(audio_bytes).decode("utf-8") -def upload_images_to_comfyapi( +async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str, str]] = None, @@ -561,55 +549,15 @@ def upload_images_to_comfyapi( mime_type: Optional MIME type for the image. """ # if batch, try to upload each file if max_images is greater than 0 - idx_image = 0 download_urls: list[str] = [] is_batch = len(image.shape) > 3 - batch_length = 1 - if is_batch: - batch_length = image.shape[0] - while True: - curr_image = image - if len(image.shape) > 3: - curr_image = image[idx_image] - # get BytesIO version of image - img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type) - # first, request upload/download urls from comfy API - if not mime_type: - request_object = UploadRequest(file_name=img_binary.name) - else: - request_object = UploadRequest( - file_name=img_binary.name, content_type=mime_type - ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - response = operation.execute() + batch_len = image.shape[0] if is_batch else 1 - upload_response = ApiClient.upload_file( - response.upload_url, img_binary, content_type=mime_type - ) - # verify success - try: - upload_response.raise_for_status() - except requests.exceptions.HTTPError as e: - raise ValueError(f"Could not upload one or more images: {e}") from e - # add download_url to list - download_urls.append(response.download_url) - - idx_image += 1 - # stop uploading additional files if done - if is_batch and max_images > 0: - if idx_image >= max_images: - break - if idx_image >= batch_length: - break + for idx in range(min(batch_len, max_images)): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs) + download_urls.append(url) return download_urls diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 2a4bac88b..4ad0b783b 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -43,7 +43,7 @@ operation = ApiOperation( endpoint=user_info_endpoint, request=request ) -user_profile = operation.execute(client=api_client) # Returns immediately with the result +user_profile = await operation.execute(client=api_client) # Returns immediately with the result # Example 2: Asynchronous API Operation with Polling @@ -87,18 +87,19 @@ operation = PollingOperation( ) # This will make the initial request and then poll until completion -result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done +result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done """ from __future__ import annotations +import aiohttp +import asyncio import logging -import time import io import socket +from aiohttp.client_exceptions import ClientError, ClientResponseError from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple from enum import Enum import json -import requests from urllib.parse import urljoin, urlparse from pydantic import BaseModel, Field import uuid # For generating unique operation IDs @@ -174,6 +175,7 @@ class ApiClient: retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, retry_status_codes: Optional[Tuple[int, ...]] = None, + session: Optional[aiohttp.ClientSession] = None, ): self.base_url = base_url self.auth_token = auth_token @@ -186,13 +188,16 @@ class ApiClient: # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), # 500, 502, 503, 504 (Server Errors) self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) + self._session: Optional[aiohttp.ClientSession] = session + self._owns_session = session is None # Track if we have to close it - def _generate_operation_id(self, path: str) -> str: + @staticmethod + def _generate_operation_id(path: str) -> str: """Generates a unique operation ID for logging.""" return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" + @staticmethod def _create_json_payload_args( - self, data: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: @@ -203,31 +208,53 @@ class ApiClient: def _create_form_data_args( self, - data: Dict[str, Any], - files: Dict[str, Any], + data: Dict[str, Any] | None, + files: Dict[str, Any] | None, headers: Optional[Dict[str, str]] = None, - multipart_parser = None, + multipart_parser: Callable | None = None, ) -> Dict[str, Any]: if headers and "Content-Type" in headers: del headers["Content-Type"] - if multipart_parser: + if multipart_parser and data: data = multipart_parser(data) - return { - "data": data, - "files": files, - "headers": headers, - } + form = aiohttp.FormData(default_to_multipart=True) + if data: # regular text fields + for k, v in data.items(): + if v is None: + continue # aiohttp fails to serialize "None" values + # aiohttp expects strings or bytes; convert enums etc. + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if files: + file_iter = files if isinstance(files, list) else files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue # aiohttp fails to serialize "None" values + # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple + if isinstance(file_obj, tuple): + filename, file_value, content_type = self._unpack_tuple(file_obj) + else: + file_value = file_obj + filename = getattr(file_obj, "name", field_name) + content_type = "application/octet-stream" + + form.add_field( + name=field_name, + value=file_value, + filename=filename, + content_type=content_type, + ) + return {"data": form, "headers": headers or {}} + + @staticmethod def _create_urlencoded_form_data_args( - self, data: Dict[str, Any], headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: headers = headers or {} headers["Content-Type"] = "application/x-www-form-urlencoded" - return { "data": data, "headers": headers, @@ -244,7 +271,7 @@ class ApiClient: return headers - def _check_connectivity(self, target_url: str) -> Dict[str, bool]: + async def _check_connectivity(self, target_url: str) -> Dict[str, bool]: """ Check connectivity to determine if network issues are local or server-related. @@ -258,52 +285,39 @@ class ApiClient: "internet_accessible": False, "api_accessible": False, "is_local_issue": False, - "is_api_issue": False + "is_api_issue": False, } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp: + results["internet_accessible"] = resp.status < 500 + except (ClientError, asyncio.TimeoutError, socket.gaierror): + results["is_local_issue"] = True + return results # cannot reach the internet – early exit - # First check basic internet connectivity using a reliable external site - try: - # Use a reliable external domain for checking basic connectivity - check_response = requests.get("https://www.google.com", - timeout=5.0, - verify=self.verify_ssl) - if check_response.status_code < 500: - results["internet_accessible"] = True - except (requests.RequestException, socket.error): - results["internet_accessible"] = False - results["is_local_issue"] = True - return results - - # Now check API server connectivity - try: - # Extract domain from the target URL to do a simpler health check - parsed_url = urlparse(target_url) - api_base = f"{parsed_url.scheme}://{parsed_url.netloc}" - - # Try to reach the API domain - api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl) - if api_response.status_code < 500: - results["api_accessible"] = True - else: - results["api_accessible"] = False - results["is_api_issue"] = True - except requests.RequestException: - results["api_accessible"] = False - # If we can reach the internet but not the API, it's an API issue - results["is_api_issue"] = True + # Now check API health endpoint + parsed = urlparse(target_url) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + try: + async with session.get(health_url, ssl=self.verify_ssl) as resp: + results["api_accessible"] = resp.status < 500 + except ClientError: + pass # leave as False + results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] return results - def request( + async def request( self, method: str, path: str, params: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None, - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, headers: Optional[Dict[str, str]] = None, content_type: str = "application/json", - multipart_parser: Callable = None, + multipart_parser: Callable | None = None, retry_count: int = 0, # Used internally for tracking retries ) -> Dict[str, Any]: """ @@ -327,18 +341,19 @@ class ApiClient: ApiServerError: If the API server is unreachable but internet is working Exception: For other request failures """ - # Use urljoin but ensure path is relative to avoid absolute path behavior - relative_path = path.lstrip('/') + + # Build full URL and merge headers + relative_path = path.lstrip("/") url = urljoin(self.base_url, relative_path) - self.check_auth(self.auth_token, self.comfy_api_key) - # Combine default headers with any provided headers + self._check_auth(self.auth_token, self.comfy_api_key) + request_headers = self.get_headers() if headers: request_headers.update(headers) - - # Let requests handle the content type when files are present. if files: - del request_headers["Content-Type"] + request_headers.pop("Content-Type", None) + if params: + params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values logging.debug(f"[DEBUG] Request Headers: {request_headers}") logging.debug(f"[DEBUG] Files: {files}") @@ -346,11 +361,9 @@ class ApiClient: logging.debug(f"[DEBUG] Data: {data}") if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data, request_headers) + payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args( - data, files, request_headers, multipart_parser - ) + payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser) else: payload_args = self._create_json_payload_args(data, request_headers) @@ -361,220 +374,67 @@ class ApiClient: request_url=url, request_headers=request_headers, request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]" + request_data=data if content_type == "application/json" else "[form-data or other]", ) + session = await self._get_session() try: - response = requests.request( - method=method, - url=url, + async with session.request( + method, + url, params=params, - timeout=self.timeout, - verify=self.verify_ssl, + ssl=self.verify_ssl, **payload_args, - ) + ) as resp: + if resp.status >= 400: + try: + error_data = await resp.json() + except (aiohttp.ContentTypeError, json.JSONDecodeError): + error_data = await resp.text() - # Check if we should retry based on status code - if (response.status_code in self.retry_status_codes and - retry_count < self.max_retries): + return await self._handle_http_error( + ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data), + operation_id, + method, + url, + params, + data, + files, + headers, + content_type, + multipart_parser, + retry_count=retry_count, + response_content=error_data, + ) - # Calculate delay with exponential backoff - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - - logging.warning( - f"Request failed with status {response.status_code}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # Raise exception for error status codes - response.raise_for_status() - - # Log successful response - response_content_to_log = response.content - try: - # Attempt to parse JSON for prettier logging, fallback to raw content - response_content_to_log = response.json() - except json.JSONDecodeError: - pass # Keep as bytes/str if not JSON - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, # Pass request details again for context in log - request_url=url, - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content=response_content_to_log - ) - - except requests.ConnectionError as e: - error_message = f"ConnectionError: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - error_message=error_message - ) - # Only perform connectivity check if we've exhausted all retries - if retry_count >= self.max_retries: - # Check connectivity to determine if it's a local or API issue - connectivity = self._check_connectivity(self.base_url) - - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - elif connectivity["is_api_issue"]: - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e - - # If we haven't exhausted retries yet, retry the request - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Connection error: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # If we've exhausted retries and didn't identify the specific issue, - # raise a generic exception - final_error_message = ( - f"Unable to connect to the API server after {self.max_retries} attempts. " - f"Please check your internet connection or try again later." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.Timeout as e: - error_message = f"Timeout: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - error_message=error_message - ) - # Retry timeouts if we haven't exhausted retries - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Request timed out. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - final_error_message = ( - f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. " - f"The server might be experiencing high load or the operation is taking longer than expected." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.HTTPError as e: - status_code = e.response.status_code if hasattr(e, "response") else None - original_error_message = f"HTTP Error: {str(e)}" - error_content_for_log = None - if hasattr(e, "response") and e.response is not None: - error_content_for_log = e.response.content + # Success – parse JSON (safely) and log try: - error_content_for_log = e.response.json() - except json.JSONDecodeError: - pass + payload = await resp.json() + response_content_to_log = payload + except (aiohttp.ContentTypeError, json.JSONDecodeError): + payload = {} + response_content_to_log = await resp.text() - - # Try to extract detailed error message from JSON response for user display - # but log the full error content. - user_display_error_message = original_error_message - - try: - if hasattr(e, "response") and e.response is not None and e.response.content: - error_json = e.response.json() - if "error" in error_json and "message" in error_json["error"]: - user_display_error_message = f"API Error: {error_json['error']['message']}" - if "type" in error_json["error"]: - user_display_error_message += f" (Type: {error_json['error']['type']})" - elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict - user_display_error_message = f"API Error: {json.dumps(error_json)}" - else: # Non-dict JSON error - user_display_error_message = f"API Error: {str(error_json)}" - except json.JSONDecodeError: - # If not JSON, use the raw content if it's not too long, or a summary - if hasattr(e, "response") and e.response is not None and e.response.content: - raw_content = e.response.content.decode(errors='ignore') - if len(raw_content) < 200: # Arbitrary limit for display - user_display_error_message = f"API Error (raw): {raw_content}" - else: - user_display_error_message = f"API Error (raw, status {status_code})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - response_status_code=status_code, - response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None, - response_content=error_content_for_log, - error_message=original_error_message # Log the original exception string as error - ) - - logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})") - if hasattr(e, "response") and e.response is not None and e.response.content: - logging.debug(f"[DEBUG] Response content: {e.response.content}") - - # Retry if the status code is in our retry list and we haven't exhausted retries - if (status_code in self.retry_status_codes and - retry_count < self.max_retries): - - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"HTTP error {status_code}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, ) - time.sleep(delay) - return self.request( - method=method, - path=path, + return payload + + except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + # Treat as *connection* problem – optionally retry, else escalate + if retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1, + self.max_retries, str(e)) + await asyncio.sleep(delay) + return await self.request( + method, + path, params=params, data=data, files=files, @@ -583,40 +443,34 @@ class ApiClient: multipart_parser=multipart_parser, retry_count=retry_count + 1, ) + # One final connectivity check for diagnostics + connectivity = await self._check_connectivity(self.base_url) + if connectivity["is_local_issue"]: + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + raise ApiServerError( + f"The API server at {self.base_url} is currently unreachable. " + f"The service may be experiencing issues. Please try again later." + ) from e - # Specific error messages for common status codes for user display - if status_code == 401: - user_display_error_message = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_display_error_message = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_display_error_message = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_display_error_message = "Rate Limit Exceeded: Please try again later." - # else, user_display_error_message remains as parsed from response or original HTTPError string - - raise Exception(user_display_error_message) # Raise with the user-friendly message - - # Parse and return JSON response - if response.content: - return response.json() - return {} - - def check_auth(self, auth_token, comfy_api_key): + @staticmethod + def _check_auth(auth_token, comfy_api_key): """Verify that an auth token is present or comfy_api_key is present""" if auth_token is None and comfy_api_key is None: raise Exception("Unauthorized: Please login first to use this node.") return auth_token or comfy_api_key @staticmethod - def upload_file( + async def upload_file( upload_url: str, file: io.BytesIO | str, content_type: str | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, - ): + ) -> aiohttp.ClientResponse: """Upload a file to the API with retry logic. Args: @@ -627,112 +481,167 @@ class ApiClient: retry_delay: Initial delay between retries in seconds retry_backoff_factor: Multiplier for the delay after each retry """ - headers = {} + headers: Dict[str, str] = {} + skip_auto_headers: set[str] = set() if content_type: headers["Content-Type"] = content_type + else: + # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status. + skip_auto_headers.add("Content-Type") - # Prepare the file data + # Extract file bytes if isinstance(file, io.BytesIO): - file.seek(0) # Ensure we're at the start of the file + file.seek(0) data = file.read() elif isinstance(file, str): with open(file, "rb") as f: data = f.read() else: - raise ValueError("File must be either a BytesIO object or a file path string") + raise ValueError("File must be BytesIO or str path") - # Try the upload with retries - last_exception = None - operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads - - # Log initial attempt (without full file data for brevity) + operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" request_logger.log_request_response( operation_id=operation_id, request_method="PUT", request_url=upload_url, request_headers=headers, - request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]" + request_data=f"[File data {len(data)} bytes]", ) - for retry_attempt in range(max_retries + 1): + delay = retry_delay + for attempt in range(max_retries + 1): try: - response = requests.put(upload_url, data=data, headers=headers) - response.raise_for_status() + timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.put( + upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers, + ) as resp: + resp.raise_for_status() + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + return resp + except (ClientError, asyncio.TimeoutError) as e: request_logger.log_request_response( operation_id=operation_id, - request_method="PUT", request_url=upload_url, # For context - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content="File uploaded successfully." # Or response.text if available + request_method="PUT", + request_url=upload_url, + response_status_code=e.status if hasattr(e, "status") else None, + response_headers=dict(e.headers) if getattr(e, "headers") else None, + response_content=None, + error_message=f"{type(e).__name__}: {str(e)}", ) - return response - - except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e: - last_exception = e - error_message_for_log = f"{type(e).__name__}: {str(e)}" - response_content_for_log = None - status_code_for_log = None - headers_for_log = None - - if hasattr(e, 'response') and e.response is not None: - status_code_for_log = e.response.status_code - headers_for_log = dict(e.response.headers) - try: - response_content_for_log = e.response.json() - except json.JSONDecodeError: - response_content_for_log = e.response.content - - - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - response_status_code=status_code_for_log, - response_headers=headers_for_log, - response_content=response_content_for_log, - error_message=error_message_for_log - ) - - if retry_attempt < max_retries: - delay = retry_delay * (retry_backoff_factor ** retry_attempt) + if attempt < max_retries: logging.warning( - f"File upload failed: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})" + "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e) ) - time.sleep(delay) + await asyncio.sleep(delay) + delay *= retry_backoff_factor else: - break # Max retries reached + raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e - # If we've exhausted all retries, determine the final error type and raise - final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}" - try: - # Check basic internet connectivity - check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired - if check_response.status_code >= 500: # Google itself has an issue (rare) - final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed " - f"(status {check_response.status_code}). Original error: {str(last_exception)}") - # Not raising LocalNetworkError here as Google itself might be down. - # If Google is reachable, the issue is likely with the upload server or a more specific local problem - # not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall). - # The original last_exception is probably most relevant. + async def _handle_http_error( + self, + exc: ClientResponseError, + operation_id: str, + *req_meta, + retry_count: int, + response_content: dict | str = "", + ) -> Dict[str, Any]: + status_code = exc.status + if status_code == 401: + user_friendly = "Unauthorized: Please login first to use this node." + elif status_code == 402: + user_friendly = "Payment Required: Please add credits to your account to use this node." + elif status_code == 409: + user_friendly = "There is a problem with your account. Please contact support@comfy.org." + elif status_code == 429: + user_friendly = "Rate Limit Exceeded: Please try again later." + else: + if isinstance(response_content, dict): + if "error" in response_content and "message" in response_content["error"]: + user_friendly = f"API Error: {response_content['error']['message']}" + if "type" in response_content["error"]: + user_friendly += f" (Type: {response_content['error']['type']})" + else: # Handle cases where error is just a JSON dict with unknown format + user_friendly = f"API Error: {json.dumps(response_content)}" + else: + if len(response_content) < 200: # Arbitrary limit for display + user_friendly = f"API Error (raw): {response_content}" + else: + user_friendly = f"API Error (raw, status {response_content})" - except (requests.RequestException, socket.error) as conn_check_exc: - # Could not reach Google, likely a local network issue - final_error_message = (f"Failed to upload file due to network connectivity issues " - f"(cannot reach Google: {str(conn_check_exc)}). " - f"Original upload error: {str(last_exception)}") - request_logger.log_request_response( # Log final failure reason - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message - ) - raise LocalNetworkError(final_error_message) from last_exception - - request_logger.log_request_response( # Log final failure reason if not LocalNetworkError + request_logger.log_request_response( operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message + request_method=req_meta[0], + request_url=req_meta[1], + response_status_code=exc.status, + response_headers=dict(req_meta[5]) if req_meta[5] else None, + response_content=response_content, + error_message=f"HTTP Error {exc.status}", ) - raise Exception(final_error_message) from last_exception + + logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})") + if response_content: + logging.debug(f"[DEBUG] Response content: {response_content}") + + # Retry if eligible + if status_code in self.retry_status_codes and retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + "HTTP error %s. Retrying in %.2fs (%s/%s)", + status_code, + delay, + retry_count + 1, + self.max_retries, + ) + await asyncio.sleep(delay) + return await self.request( + req_meta[0], # method + req_meta[1].replace(self.base_url, ""), # path + params=req_meta[2], + data=req_meta[3], + files=req_meta[4], + headers=req_meta[5], + content_type=req_meta[6], + multipart_parser=req_meta[7], + retry_count=retry_count + 1, + ) + + raise Exception(user_friendly) from exc + + @staticmethod + def _unpack_tuple(t): + """Helper to normalise (filename, file, content_type) tuples.""" + if len(t) == 3: + return t + elif len(t) == 2: + return t[0], t[1], "application/octet-stream" + else: + raise ValueError("files tuple must be (filename, file[, content_type])") + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=self.timeout) + self._session = aiohttp.ClientSession(timeout=timeout) + self._owns_session = True + return self._session + + async def close(self) -> None: + if self._owns_session and self._session and not self._session.closed: + await self._session.close() + + async def __aenter__(self) -> "ApiClient": + """Allow usage as async‑context‑manager – ensures clean teardown""" + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() class ApiEndpoint(Generic[T, R]): @@ -763,31 +672,28 @@ class ApiEndpoint(Generic[T, R]): class SynchronousOperation(Generic[T, R]): - """ - Represents a single synchronous API operation. - """ + """Represents a single synchronous API operation.""" def __init__( self, endpoint: ApiEndpoint[T, R], request: T, - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, + auth_kwargs: Optional[Dict[str, str]] = None, timeout: float = 604800.0, verify_ssl: bool = True, content_type: str = "application/json", - multipart_parser: Callable = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, - ): + ) -> None: self.endpoint = endpoint self.request = request - self.response = None - self.error = None + self.files = files self.api_base: str = api_base or args.comfy_api_base self.auth_token = auth_token self.comfy_api_key = comfy_api_key @@ -796,91 +702,64 @@ class SynchronousOperation(Generic[T, R]): self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) self.timeout = timeout self.verify_ssl = verify_ssl - self.files = files self.content_type = content_type self.multipart_parser = multipart_parser self.max_retries = max_retries self.retry_delay = retry_delay self.retry_backoff_factor = retry_backoff_factor - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the API operation using the provided client or create one with retry support""" - try: - # Create client if not provided - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - # Convert request model to dict, but use None for EmptyRequest - request_dict = ( - None - if isinstance(self.request, EmptyRequest) - else self.request.model_dump(exclude_none=True) + async def execute(self, client: Optional[ApiClient] = None) -> R: + owns_client = client is None + if owns_client: + client = ApiClient( + base_url=self.api_base, + auth_token=self.auth_token, + comfy_api_key=self.comfy_api_key, + timeout=self.timeout, + verify_ssl=self.verify_ssl, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, ) - if request_dict: - for key, value in request_dict.items(): - if isinstance(value, Enum): - request_dict[key] = value.value - # Debug log for request + try: + request_dict: Optional[Dict[str, Any]] + if isinstance(self.request, EmptyRequest): + request_dict = None + else: + request_dict = self.request.model_dump(exclude_none=True) + for k, v in list(request_dict.items()): + if isinstance(v, Enum): + request_dict[k] = v.value + logging.debug( f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" ) logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") - # Make the request with built-in retry - resp = client.request( - method=self.endpoint.method.value, - path=self.endpoint.path, - data=request_dict, + response_json = await client.request( + self.endpoint.method.value, + self.endpoint.path, params=self.endpoint.query_params, + data=request_dict, files=self.files, content_type=self.content_type, - multipart_parser=self.multipart_parser + multipart_parser=self.multipart_parser, ) - # Debug log for response logging.debug("=" * 50) logging.debug("[DEBUG] RESPONSE DETAILS:") logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}") + logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}") logging.debug("=" * 50) - # Parse and return the response - return self._parse_response(resp) - - except LocalNetworkError as e: - # Propagate specific network error types - logging.error(f"[ERROR] Local network error: {str(e)}") - raise - - except ApiServerError as e: - # Propagate API server errors - logging.error(f"[ERROR] API server error: {str(e)}") - raise - - except Exception as e: - logging.error(f"[ERROR] API Exception: {str(e)}") - raise Exception(str(e)) - - def _parse_response(self, resp): - """Parse response data - can be overridden by subclasses""" - # The response is already the complete object, don't extract just the "data" field - # as that would lose the outer structure (created timestamp, etc.) - - # Parse response using the provided model - self.response = self.endpoint.response_model.model_validate(resp) - logging.debug(f"[DEBUG] Parsed Response: {self.response}") - return self.response + parsed_response = self.endpoint.response_model.model_validate(response_json) + logging.debug(f"[DEBUG] Parsed Response: {parsed_response}") + return parsed_response + finally: + if owns_client: + await client.close() class TaskStatus(str, Enum): @@ -892,23 +771,21 @@ class TaskStatus(str, Enum): class PollingOperation(Generic[T, R]): - """ - Represents an asynchronous API operation that requires polling for completion. - """ + """Represents an asynchronous API operation that requires polling for completion.""" def __init__( self, poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list, - failed_statuses: list, + completed_statuses: list[str], + failed_statuses: list[str], status_extractor: Callable[[R], str], - progress_extractor: Callable[[R], float] = None, - result_url_extractor: Callable[[R], str] = None, + progress_extractor: Callable[[R], float] | None = None, + result_url_extractor: Callable[[R], str] | None = None, request: Optional[T] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, + auth_kwargs: Optional[Dict[str, str]] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) max_retries: int = 3, # Max retries per individual API call @@ -916,7 +793,7 @@ class PollingOperation(Generic[T, R]): retry_backoff_factor: float = 2.0, estimated_duration: Optional[float] = None, node_id: Optional[str] = None, - ): + ) -> None: self.poll_endpoint = poll_endpoint self.request = request self.api_base: str = api_base or args.comfy_api_base @@ -931,100 +808,73 @@ class PollingOperation(Generic[T, R]): self.retry_delay = retry_delay self.retry_backoff_factor = retry_backoff_factor self.estimated_duration = estimated_duration - - # Polling configuration - self.status_extractor = status_extractor or ( - lambda x: getattr(x, "status", None) - ) + self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) self.progress_extractor = progress_extractor self.result_url_extractor = result_url_extractor self.node_id = node_id self.completed_statuses = completed_statuses self.failed_statuses = failed_statuses + self.final_response: Optional[R] = None - # For storing response data - self.final_response = None - self.error = None - - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the polling operation using the provided client. If failed, raise an exception.""" + async def execute(self, client: Optional[ApiClient] = None) -> R: + owns_client = client is None + if owns_client: + client = ApiClient( + base_url=self.api_base, + auth_token=self.auth_token, + comfy_api_key=self.comfy_api_key, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, + ) try: - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - return self._poll_until_complete(client) - except LocalNetworkError as e: - # Provide clear message for local network issues - raise Exception( - f"Polling failed due to local network issues. Please check your internet connection. " - f"Details: {str(e)}" - ) from e - except ApiServerError as e: - # Provide clear message for API server issues - raise Exception( - f"Polling failed due to API server issues. The service may be experiencing problems. " - f"Please try again later. Details: {str(e)}" - ) from e - except Exception as e: - raise Exception(f"Error during polling: {str(e)}") + return await self._poll_until_complete(client) + finally: + if owns_client: + await client.close() def _display_text_on_node(self, text: str): - """Sends text to the client which will be displayed on the node in the UI""" if not self.node_id: return - PromptServer.instance.send_progress_text(text, self.node_id) - def _display_time_progress_on_node(self, time_completed: int): + def _display_time_progress_on_node(self, time_completed: int | float): if not self.node_id: return - if self.estimated_duration is not None: - estimated_time_remaining = max( - 0, int(self.estimated_duration) - int(time_completed) - ) - message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)" + remaining = max(0, int(self.estimated_duration) - time_completed) + message = f"Task in progress: {time_completed}s (~{remaining}s remaining)" else: - message = f"Task in progress: {time_completed:.0f}s" + message = f"Task in progress: {time_completed}s" self._display_text_on_node(message) def _check_task_status(self, response: R) -> TaskStatus: - """Check task status using the status extractor function""" try: status = self.status_extractor(response) if status in self.completed_statuses: return TaskStatus.COMPLETED - elif status in self.failed_statuses: + if status in self.failed_statuses: return TaskStatus.FAILED return TaskStatus.PENDING except Exception as e: - logging.error(f"Error extracting status: {e}") + logging.error("Error extracting status: %s", e) return TaskStatus.PENDING - def _poll_until_complete(self, client: ApiClient) -> R: + async def _poll_until_complete(self, client: ApiClient) -> R: """Poll until the task is complete""" - poll_count = 0 consecutive_errors = 0 max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors if self.progress_extractor: progress = utils.ProgressBar(PROGRESS_BAR_MAX) - while poll_count < self.max_poll_attempts: + status = TaskStatus.PENDING + for poll_count in range(1, self.max_poll_attempts + 1): try: - poll_count += 1 logging.debug(f"[DEBUG] Polling attempt #{poll_count}") request_dict = ( - self.request.model_dump(exclude_none=True) - if self.request is not None - else None + None if self.request is None else self.request.model_dump(exclude_none=True) ) if poll_count == 1: @@ -1036,18 +886,14 @@ class PollingOperation(Generic[T, R]): ) # Query task status - resp = client.request( - method=self.poll_endpoint.method.value, - path=self.poll_endpoint.path, + resp = await client.request( + self.poll_endpoint.method.value, + self.poll_endpoint.path, params=self.poll_endpoint.query_params, data=request_dict, ) - - # Successfully got a response, reset consecutive error count - consecutive_errors = 0 - - # Parse response - response_obj = self.poll_endpoint.response_model.model_validate(resp) + consecutive_errors = 0 # reset on success + response_obj: R = self.poll_endpoint.response_model.model_validate(resp) # Check if task is complete status = self._check_task_status(response_obj) @@ -1065,45 +911,30 @@ class PollingOperation(Generic[T, R]): result_url = self.result_url_extractor(response_obj) if result_url: message = f"Result URL: {result_url}" - else: - message = "Task completed successfully!" logging.debug(f"[DEBUG] {message}") self._display_text_on_node(message) self.final_response = response_obj if self.progress_extractor: progress.update(100) return self.final_response - elif status == TaskStatus.FAILED: + if status == TaskStatus.FAILED: message = f"Task failed: {json.dumps(resp)}" logging.error(f"[DEBUG] {message}") raise Exception(message) - else: - logging.debug("[DEBUG] Task still pending, continuing to poll...") - - # Wait before polling again - logging.debug( - f"[DEBUG] Waiting {self.poll_interval} seconds before next poll" - ) + logging.debug("[DEBUG] Task still pending, continuing to poll...") + # Task pending – wait for i in range(int(self.poll_interval)): - time_completed = (poll_count * self.poll_interval) + i - self._display_time_progress_on_node(time_completed) - time.sleep(1) + self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i) + await asyncio.sleep(1) - except (LocalNetworkError, ApiServerError) as e: - # For network-related errors, increment error count and potentially abort + except (LocalNetworkError, ApiServerError, NetworkError) as e: consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: raise Exception( - f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}" + f"Polling aborted after {consecutive_errors} network errors: {str(e)}" ) from e - - # Log the error but continue polling - logging.warning( - f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " - f"Will retry in {self.poll_interval} seconds." - ) - time.sleep(self.poll_interval) - + logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e)) + await asyncio.sleep(self.poll_interval) except Exception as e: # For other errors, increment count and potentially abort consecutive_errors += 1 @@ -1117,10 +948,10 @@ class PollingOperation(Generic[T, R]): f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " f"Will retry in {self.poll_interval} seconds." ) - time.sleep(self.poll_interval) + await asyncio.sleep(self.poll_interval) # If we've exhausted all polling attempts raise Exception( - f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). " - f"The operation may still be running on the server but is taking longer than expected." + f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). " + "The operation may still be running on the server but is taking longer than expected." ) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index d93fbd778..c09be8d5b 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,3 +1,4 @@ +import asyncio import io from inspect import cleandoc from typing import Union, Optional @@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import ( import numpy as np from PIL import Image -import requests +import aiohttp import torch import base64 import time @@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor): return mask -def handle_bfl_synchronous_operation( +async def handle_bfl_synchronous_operation( operation: SynchronousOperation, timeout_bfl_calls=360, node_id: Union[str, None] = None, ): - response_api: BFLFluxProGenerateResponse = operation.execute() - return _poll_until_generated( + response_api: BFLFluxProGenerateResponse = await operation.execute() + return await _poll_until_generated( response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id ) -def _poll_until_generated( +async def _poll_until_generated( polling_url: str, timeout=360, node_id: Union[str, None] = None ): # used bfl-comfy-nodes to verify code implementation: @@ -66,55 +67,56 @@ def _poll_until_generated( retry_404_seconds = 2 retry_202_seconds = 2 retry_pending_seconds = 1 - request = requests.Request(method=HttpMethod.GET, url=polling_url) - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - response = requests.Session().send(request.prepare()) - if response.status_code == 200: - result = response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - img_response = requests.get(img_url) - return process_image_response(img_response) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." + async with aiohttp.ClientSession() as session: + # NOTE: should True loop be replaced with checking if workflow has been interrupted? + while True: + if node_id: + time_elapsed = time.time() - start_time + PromptServer.instance.send_progress_text( + f"Generating ({time_elapsed:.0f}s)", node_id ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - time.sleep(retry_pending_seconds) - continue - elif response.status_code == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - time.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status_code == 202: - time.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." - ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") + + async with session.get(polling_url) as response: + if response.status == 200: + result = await response.json() + if result["status"] == BFLStatus.ready: + img_url = result["result"]["sample"] + if node_id: + PromptServer.instance.send_progress_text( + f"Result URL: {img_url}", node_id + ) + async with session.get(img_url) as img_resp: + return process_image_response(await img_resp.content.read()) + elif result["status"] in [ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + ]: + status = result["status"] + raise Exception( + f"BFL API did not return an image due to: {status}." + ) + elif result["status"] == BFLStatus.error: + raise Exception(f"BFL API encountered an error: {result}.") + elif result["status"] == BFLStatus.pending: + await asyncio.sleep(retry_pending_seconds) + continue + elif response.status == 404: + if retries_404 < max_retries_404: + retries_404 += 1 + await asyncio.sleep(retry_404_seconds) + continue + raise Exception( + f"BFL API could not find task after {max_retries_404} tries." + ) + elif response.status == 202: + await asyncio.sleep(retry_202_seconds) + elif time.time() - start_time > timeout: + raise Exception( + f"BFL API experienced a timeout; could not return request under {timeout} seconds." + ) + else: + raise Exception(f"BFL API encountered an error: {response.json()}") def convert_image_to_base64(image: torch.Tensor): scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) @@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC): BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, prompt: str, prompt_upsampling, @@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, image: torch.Tensor, mask: torch.Tensor, @@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, control_image: torch.Tensor, prompt: str, @@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, control_image: torch.Tensor, prompt: str, @@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index af33279d5..3751fb2a1 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -303,7 +303,7 @@ class GeminiNode(ComfyNodeABC): """ return GeminiPart(text=text) - def api_call( + async def api_call( self, prompt: str, model: GeminiModel, @@ -332,7 +332,7 @@ class GeminiNode(ComfyNodeABC): parts.extend(files) # Create response - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=get_gemini_endpoint(model), request=GeminiGenerateContentRequest( contents=[ diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index b8487355f..db24e6da4 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -212,7 +212,7 @@ V3_RESOLUTIONS= [ "1536x640" ] -def download_and_process_images(image_urls): +async def download_and_process_images(image_urls): """Helper function to download and process multiple images from URLs""" # Initialize list to store image tensors @@ -220,7 +220,7 @@ def download_and_process_images(image_urls): for image_url in image_urls: # Using functions from apinode_utils.py to handle downloading and processing - image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO + image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode image_tensors.append(img_tensor) @@ -328,7 +328,7 @@ class IdeogramV1(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, turbo=False, @@ -367,7 +367,7 @@ class IdeogramV1(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -378,7 +378,7 @@ class IdeogramV1(ComfyNodeABC): raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) class IdeogramV2(ComfyNodeABC): @@ -487,7 +487,7 @@ class IdeogramV2(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, turbo=False, @@ -543,7 +543,7 @@ class IdeogramV2(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -554,7 +554,7 @@ class IdeogramV2(ComfyNodeABC): raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) class IdeogramV3(ComfyNodeABC): """ @@ -653,7 +653,7 @@ class IdeogramV3(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, image=None, @@ -774,7 +774,7 @@ class IdeogramV3(ComfyNodeABC): ) # Execute the operation and process response - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -785,7 +785,7 @@ class IdeogramV3(ComfyNodeABC): raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 69e9e5cf0..9d9eb5628 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -109,7 +109,7 @@ class KlingApiError(Exception): pass -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, R], result_url_extractor: Optional[Callable[[R], str]] = None, @@ -117,7 +117,7 @@ def poll_until_finished( node_id: Optional[str] = None, ) -> R: """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ KlingTaskStatus.succeed.value, @@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]: return None -def video_result_to_node_output( +async def video_result_to_node_output( video: KlingVideoResult, ) -> tuple[VideoFromFile, str, str]: """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output.""" return ( - download_url_to_video_output(video.url), + await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration), ) -def image_result_to_node_output( +async def image_result_to_node_output( images: list[KlingImageResult], ) -> torch.Tensor: """ @@ -297,9 +297,9 @@ def image_result_to_node_output( If multiple images are returned, they will be stacked along the batch dimension. """ if len(images) == 1: - return download_url_to_image_tensor(images[0].url) + return await download_url_to_image_tensor(str(images[0].url)) else: - return torch.cat([download_url_to_image_tensor(image.url) for image in images]) + return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images]) class KlingNodeBase(ComfyNodeABC): @@ -467,10 +467,10 @@ class KlingTextToVideoNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Text to Video Node" - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingText2VideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", @@ -483,7 +483,7 @@ class KlingTextToVideoNode(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -519,17 +519,17 @@ class KlingTextToVideoNode(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingCameraControlT2VNode(KlingTextToVideoNode): @@ -581,7 +581,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text." - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -591,7 +591,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, mode=KlingVideoGenMode.std, @@ -670,10 +670,10 @@ class KlingImage2VideoNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Image to Video Node" - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingImage2VideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", @@ -686,7 +686,7 @@ class KlingImage2VideoNode(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, start_frame: torch.Tensor, prompt: str, @@ -733,17 +733,17 @@ class KlingImage2VideoNode(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingCameraControlI2VNode(KlingImage2VideoNode): @@ -798,7 +798,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image." - def api_call( + async def api_call( self, start_frame: torch.Tensor, prompt: str, @@ -809,7 +809,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, @@ -897,7 +897,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last." - def api_call( + async def api_call( self, start_frame: torch.Tensor, end_frame: torch.Tensor, @@ -912,7 +912,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode): mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ mode ] - return super().api_call( + return await super().api_call( prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, @@ -964,10 +964,10 @@ class KlingVideoExtendNode(KlingNodeBase): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVideoExtendResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIDEO_EXTEND}/{task_id}", @@ -980,7 +980,7 @@ class KlingVideoExtendNode(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -1006,17 +1006,17 @@ class KlingVideoExtendNode(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingVideoEffectsBase(KlingNodeBase): @@ -1025,10 +1025,10 @@ class KlingVideoEffectsBase(KlingNodeBase): RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_NAMES = ("VIDEO", "video_id", "duration") - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVideoEffectsResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIDEO_EFFECTS}/{task_id}", @@ -1041,7 +1041,7 @@ class KlingVideoEffectsBase(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, dual_character: bool, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, @@ -1084,17 +1084,17 @@ class KlingVideoEffectsBase(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): @@ -1142,7 +1142,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): RETURN_TYPES = ("VIDEO", "STRING") RETURN_NAMES = ("VIDEO", "duration") - def api_call( + async def api_call( self, image_left: torch.Tensor, image_right: torch.Tensor, @@ -1153,7 +1153,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): unique_id: Optional[str] = None, **kwargs, ): - video, _, duration = super().api_call( + video, _, duration = await super().api_call( dual_character=True, effect_scene=effect_scene, model_name=model_name, @@ -1208,7 +1208,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene." - def api_call( + async def api_call( self, image: torch.Tensor, effect_scene: KlingSingleImageEffectsScene, @@ -1217,7 +1217,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( dual_character=False, effect_scene=effect_scene, model_name=model_name, @@ -1253,11 +1253,11 @@ class KlingLipSyncBase(KlingNodeBase): f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." ) - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingLipSyncResponse: """Polls the Kling API endpoint until the task reaches a terminal state.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_LIP_SYNC}/{task_id}", @@ -1270,7 +1270,7 @@ class KlingLipSyncBase(KlingNodeBase): node_id=node_id, ) - def api_call( + async def api_call( self, video: VideoInput, audio: Optional[AudioInput] = None, @@ -1287,12 +1287,12 @@ class KlingLipSyncBase(KlingNodeBase): self.validate_lip_sync_video(video) # Upload video to Comfy API and get download URL - video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs) logging.info("Uploaded video to Comfy API. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) + audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None @@ -1319,17 +1319,17 @@ class KlingLipSyncBase(KlingNodeBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): @@ -1357,7 +1357,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - def api_call( + async def api_call( self, video: VideoInput, audio: AudioInput, @@ -1365,7 +1365,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( video=video, audio=audio, voice_language=voice_language, @@ -1469,7 +1469,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - def api_call( + async def api_call( self, video: VideoInput, text: str, @@ -1479,7 +1479,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase): **kwargs, ): voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] - return super().api_call( + return await super().api_call( video=video, text=text, voice_language=voice_language, @@ -1533,10 +1533,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVirtualTryOnResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", @@ -1549,7 +1549,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): node_id=node_id, ) - def api_call( + async def api_call( self, human_image: torch.Tensor, cloth_image: torch.Tensor, @@ -1572,17 +1572,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return (await image_result_to_node_output(images),) class KlingImageGenerationNode(KlingImageGenerationBase): @@ -1655,13 +1655,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase): DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: Optional[dict[str, str]], node_id: Optional[str] = None, ) -> KlingImageGenerationsResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", @@ -1674,7 +1674,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase): node_id=node_id, ) - def api_call( + async def api_call( self, model_name: KlingImageGenModelName, prompt: str, @@ -1714,17 +1714,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase): auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return (await image_result_to_node_output(images),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 525dc38e6..b3c32bed5 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import ( ) from server import PromptServer -import requests +import aiohttp import torch from io import BytesIO @@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC): # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = self._convert_luma_refs( + api_image_ref = await self._convert_luma_refs( image_luma_ref, max_refs=4, auth_kwargs=kwargs, ) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = self._convert_style_image( + api_style_ref = await self._convert_style_image( style_image, weight=style_image_weight, auth_kwargs=kwargs, ) # handle character_ref images character_ref = None if character_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( character_image, max_images=4, auth_kwargs=kwargs, ) character_ref = LumaCharacterRef( @@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC): node_id=unique_id, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.image) as img_response: + img = process_image_response(await img_response.content.read()) return (img,) - def _convert_luma_refs( + async def _convert_luma_refs( self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None ): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( ref.image, max_images=1, auth_kwargs=auth_kwargs ) luma_urls.append(download_urls[0]) @@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC): break return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) - def _convert_style_image( + async def _convert_style_image( self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None ): chain = LumaReferenceChain( first_ref=LumaReference(image=style_image, weight=weight) ) - return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) class LumaImageModifyNode(ComfyNodeABC): @@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC): **kwargs, ): # first, upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( image, max_images=1, auth_kwargs=kwargs, ) image_url = download_urls[0] @@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC): node_id=unique_id, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.image) as img_response: + img = process_image_response(await img_response.content.read()) return (img,) @@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() if unique_id: PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) @@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC): estimated_duration=LUMA_T2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.video) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class LumaImageToVideoGenerationNode(ComfyNodeABC): @@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): raise Exception( "At least one of first_image and last_image requires an input." ) - keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) + keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None @@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() if unique_id: PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) @@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): estimated_duration=LUMA_I2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.video) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) - def _convert_to_keyframes( + async def _convert_to_keyframes( self, first_image: torch.Tensor = None, last_image: torch.Tensor = None, @@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC): frame0 = None frame1 = None if first_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( first_image, max_images=1, auth_kwargs=auth_kwargs, ) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( last_image, max_images=1, auth_kwargs=auth_kwargs, ) frame1 = LumaImageReference(type="image", url=download_urls[0]) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 9b46636db..58d2ed90c 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -86,7 +86,7 @@ class MinimaxTextToVideoNode: API_NODE = True OUTPUT_NODE = True - def generate_video( + async def generate_video( self, prompt_text, seed=0, @@ -104,12 +104,12 @@ class MinimaxTextToVideoNode: # upload image, if passed in image_url = None if image is not None: - image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0] + image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0] # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model subject_reference = None if subject is not None: - subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0] + subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0] subject_reference = [SubjectReferenceItem(image=subject_url)] @@ -130,7 +130,7 @@ class MinimaxTextToVideoNode: ), auth_kwargs=kwargs, ) - response = video_generate_operation.execute() + response = await video_generate_operation.execute() task_id = response.task_id if not task_id: @@ -151,7 +151,7 @@ class MinimaxTextToVideoNode: node_id=unique_id, auth_kwargs=kwargs, ) - task_result = video_generate_operation.execute() + task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: @@ -167,7 +167,7 @@ class MinimaxTextToVideoNode: request=EmptyRequest(), auth_kwargs=kwargs, ) - file_result = file_retrieve_operation.execute() + file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: @@ -182,7 +182,7 @@ class MinimaxTextToVideoNode: message = f"Result URL: {file_url}" PromptServer.instance.send_progress_text(message, unique_id) - video_io = download_url_to_bytesio(file_url) + video_io = await download_url_to_bytesio(file_url) if video_io is None: error_msg = f"Failed to download video from {file_url}" logging.error(error_msg) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 789fcef02..164ca3ea5 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -95,14 +95,14 @@ def get_video_url_from_response(response) -> Optional[str]: return None -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, R], result_url_extractor: Optional[Callable[[R], str]] = None, node_id: Optional[str] = None, ) -> R: """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ "completed", @@ -394,10 +394,10 @@ class BaseMoonvalleyVideoNode: else: return control_map["Motion Transfer"] - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> MoonvalleyPromptResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{API_PROMPTS_ENDPOINT}/{task_id}", @@ -507,7 +507,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): RETURN_NAMES = ("video",) DESCRIPTION = "Moonvalley Marey Image to Video Node" - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): image = kwargs.get("image", None) @@ -532,9 +532,9 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - image_url = upload_images_to_comfyapi( + image_url = (await upload_images_to_comfyapi( image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type - )[0] + ))[0] request = MoonvalleyTextToVideoRequest( image_url=image_url, prompt_text=prompt, inference_params=inference_params @@ -549,14 +549,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) @@ -609,7 +609,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): RETURN_TYPES = ("VIDEO",) RETURN_NAMES = ("video",) - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") @@ -620,7 +620,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): video_url = "" if video: validated_video = validate_video_to_video_input(video) - video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) control_type = kwargs.get("control_type") motion_intensity = kwargs.get("motion_intensity") @@ -658,15 +658,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) @@ -688,7 +688,7 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): del input_types["optional"][param] return input_types - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) @@ -717,15 +717,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index be1d2de4a..ab3c5363b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -163,7 +163,7 @@ class OpenAIDalle2(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -233,9 +233,9 @@ class OpenAIDalle2(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -311,7 +311,7 @@ class OpenAIDalle3(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -343,9 +343,9 @@ class OpenAIDalle3(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -446,7 +446,7 @@ class OpenAIGPTImage1(ComfyNodeABC): DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -537,9 +537,9 @@ class OpenAIGPTImage1(ComfyNodeABC): auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -623,7 +623,7 @@ class OpenAIChatNode(OpenAITextNode): DESCRIPTION = "Generate text responses from an OpenAI model." - def get_result_response( + async def get_result_response( self, response_id: str, include: Optional[list[Includable]] = None, @@ -639,7 +639,7 @@ class OpenAIChatNode(OpenAITextNode): creation above for more information. """ - return PollingOperation( + return await PollingOperation( poll_endpoint=ApiEndpoint( path=f"{RESPONSES_ENDPOINT}/{response_id}", method=HttpMethod.GET, @@ -784,7 +784,7 @@ class OpenAIChatNode(OpenAITextNode): self.history[session_id] = new_history - def api_call( + async def api_call( self, prompt: str, persist_context: bool, @@ -815,7 +815,7 @@ class OpenAIChatNode(OpenAITextNode): previous_response_id = None # Create response - create_response = SynchronousOperation( + create_response = await SynchronousOperation( endpoint=ApiEndpoint( path=RESPONSES_ENDPOINT, method=HttpMethod.POST, @@ -848,7 +848,7 @@ class OpenAIChatNode(OpenAITextNode): response_id = create_response.id # Get result output - result_response = self.get_result_response(response_id, auth_kwargs=kwargs) + result_response = await self.get_result_response(response_id, auth_kwargs=kwargs) output_text = self.parse_output_text_from_response(result_response) # Update history diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 1cc708564..a8dc43cb3 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC): FUNCTION = "api_call" RETURN_TYPES = ("VIDEO",) - def poll_for_task_status( + async def poll_for_task_status( self, task_id: str, auth_kwargs: Optional[dict[str, str]] = None, @@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC): node_id=node_id, estimated_duration=60 ) - return polling_operation.execute() + return await polling_operation.execute() - def execute_task( + async def execute_task( self, initial_operation: SynchronousOperation[R, PikaGenerateResponse], auth_kwargs: Optional[dict[str, str]] = None, @@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC): Returns: A tuple containing the video file as a VIDEO output. """ - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() if not is_valid_initial_response(initial_response): error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}" logging.error(error_msg) raise PikaApiError(error_msg) task_id = initial_response.video_id - final_response = self.poll_for_task_status(task_id, auth_kwargs) + final_response = await self.poll_for_task_status(task_id, auth_kwargs) if not is_valid_video_response(final_response): error_msg = ( f"Pika task {task_id} succeeded but no video data found in response." @@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC): video_url = str(final_response.url) logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) - return (download_url_to_video_output(video_url),) + return (await download_url_to_video_output(video_url),) class PikaImageToVideoV2_2(PikaNodeBase): @@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase): DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video." - def api_call( + async def api_call( self, image: torch.Tensor, prompt_text: str, @@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaTextToVideoNodeV2_2(PikaNodeBase): @@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video." - def api_call( + async def api_call( self, prompt_text: str, negative_prompt: str, @@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase): content_type="application/x-www-form-urlencoded", ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaScenesV2_2(PikaNodeBase): @@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase): DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them." - def api_call( + async def api_call( self, prompt_text: str, negative_prompt: str, @@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikAdditionsNode(PikaNodeBase): @@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase): DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result." - def api_call( + async def api_call( self, video: VideoInput, image: torch.Tensor, @@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase): image_bytes_io = tensor_to_bytesio(image) image_bytes_io.seek(0) - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ] + pika_files = { + "video": ("video.mp4", video_bytes_io, "video/mp4"), + "image": ("image.png", image_bytes_io, "image/png"), + } # Prepare non-file data pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost( @@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaSwapsNode(PikaNodeBase): @@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase): DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates." RETURN_TYPES = ("VIDEO",) - def api_call( + async def api_call( self, video: VideoInput, image: torch.Tensor, @@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase): image_bytes_io = tensor_to_bytesio(image) image_bytes_io.seek(0) - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")), - ] + pika_files = { + "video": ("video.mp4", video_bytes_io, "video/mp4"), + "image": ("image.png", image_bytes_io, "image/png"), + "modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"), + } # Prepare non-file data pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost( @@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaffectsNode(PikaNodeBase): @@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase): DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear" - def api_call( + async def api_call( self, image: torch.Tensor, pikaffect: str, @@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaStartEndFrameNode2_2(PikaNodeBase): @@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them." - def api_call( + async def api_call( self, image_start: torch.Tensor, image_end: torch.Tensor, @@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): ) -> tuple[VideoFromFile]: pika_files = [ - ( - "keyFrames", - ("image_start.png", tensor_to_bytesio(image_start), "image/png"), - ), + ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ] @@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase): auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index ef4a9a802..7c5a52feb 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC from comfy_api.input_impl import VideoFromFile import torch -import requests +import aiohttp from io import BytesIO @@ -47,7 +47,7 @@ def get_video_url_from_response( return str(response.Resp.url) -def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): +async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( @@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): content_type="multipart/form-data", auth_kwargs=auth_kwargs, ) - response_upload: PixverseImageUploadResponse = operation.execute() + response_upload: PixverseImageUploadResponse = await operation.execute() if response_upload.Resp is None: raise Exception( @@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC): result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class PixverseImageToVideoNode(ComfyNodeABC): @@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): **kwargs, ): validate_string(prompt, strip_whitespace=False) - img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs) + img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC): result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class PixverseTransitionVideoNode(ComfyNodeABC): @@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): }, } - def api_call( + async def api_call( self, first_frame: torch.Tensor, last_frame: torch.Tensor, @@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC): **kwargs, ): validate_string(prompt, strip_whitespace=False) - first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) - last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) + first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) + last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC): ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC): result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e369c4b7e..c8516b368 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -37,7 +37,7 @@ from io import BytesIO from PIL import UnidentifiedImageError -def handle_recraft_file_request( +async def handle_recraft_file_request( image: torch.Tensor, path: str, mask: torch.Tensor=None, @@ -71,13 +71,13 @@ def handle_recraft_file_request( auth_kwargs=auth_kwargs, multipart_parser=recraft_multipart_parser, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() all_bytesio = [] if response.image is not None: - all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout)) + all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) else: for data in response.data: - all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout)) + all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) return all_bytesio @@ -395,7 +395,7 @@ class RecraftTextToImageNode: }, } - def api_call( + async def api_call( self, prompt: str, size: str, @@ -439,7 +439,7 @@ class RecraftTextToImageNode: ), auth_kwargs=kwargs, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() images = [] urls = [] for data in response.data: @@ -451,7 +451,7 @@ class RecraftTextToImageNode: f"Result URL: {urls_string}", unique_id ) image = bytesio_to_image_tensor( - download_url_to_bytesio(data.url, timeout=1024) + await download_url_to_bytesio(data.url, timeout=1024) ) if len(image.shape) < 4: image = image.unsqueeze(0) @@ -538,7 +538,7 @@ class RecraftImageToImageNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -578,7 +578,7 @@ class RecraftImageToImageNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/imageToImage", request=request, @@ -654,7 +654,7 @@ class RecraftImageInpaintingNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, mask: torch.Tensor, @@ -690,7 +690,7 @@ class RecraftImageInpaintingNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], mask=mask[i:i+1], path="/proxy/recraft/images/inpaint", @@ -779,7 +779,7 @@ class RecraftTextToVectorNode: }, } - def api_call( + async def api_call( self, prompt: str, substyle: str, @@ -821,7 +821,7 @@ class RecraftTextToVectorNode: ), auth_kwargs=kwargs, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() svg_data = [] urls = [] for data in response.data: @@ -831,7 +831,7 @@ class RecraftTextToVectorNode: PromptServer.instance.send_progress_text( f"Result URL: {' '.join(urls)}", unique_id ) - svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_to_bytesio(data.url, timeout=1024)) return (SVG(svg_data),) @@ -861,7 +861,7 @@ class RecraftVectorizeImageNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -870,7 +870,7 @@ class RecraftVectorizeImageNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/vectorize", auth_kwargs=kwargs, @@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, @@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/removeBackground", auth_kwargs=kwargs, @@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode: }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode: total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path=self.RECRAFT_PATH, auth_kwargs=kwargs, diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 67f90478c..c89d087e5 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -9,11 +9,10 @@ from __future__ import annotations from inspect import cleandoc from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths -import requests +import aiohttp import os import datetime -import shutil -import time +import asyncio import io import logging import math @@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse): return hasattr(response, "error") - class Rodin3DAPI: """ Generate 3D Assets using Rodin API @@ -123,8 +121,8 @@ class Rodin3DAPI: else: return "Generating" - def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): - if images == None: + async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): + if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") if len(images) >= 5: raise Exception("Rodin 3D generate requires up to 5 image.") @@ -155,7 +153,7 @@ class Rodin3DAPI: auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if create_task_error(response): error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" @@ -168,7 +166,7 @@ class Rodin3DAPI: logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") return task_uuid, subscription_key - def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: + async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: path = "/proxy/rodin/api/v2/status" @@ -191,11 +189,9 @@ class Rodin3DAPI: logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return poll_operation.execute() + return await poll_operation.execute() - - - def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse: + async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse: logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") path = "/proxy/rodin/api/v2/download" @@ -212,53 +208,59 @@ class Rodin3DAPI: auth_kwargs=kwargs ) - return operation.execute() + return await operation.execute() - def GetQualityAndMode(self, PolyCount): - if PolyCount == "200K-Triangle": + def get_quality_mode(self, poly_count): + if poly_count == "200K-Triangle": mesh_mode = "Raw" quality = "medium" else: mesh_mode = "Quad" - if PolyCount == "4K-Quad": + if poly_count == "4K-Quad": quality = "extra-low" - elif PolyCount == "8K-Quad": + elif poly_count == "8K-Quad": quality = "low" - elif PolyCount == "18K-Quad": + elif poly_count == "18K-Quad": quality = "medium" - elif PolyCount == "50K-Quad": + elif poly_count == "50K-Quad": quality = "high" else: quality = "medium" return mesh_mode, quality - def DownLoadFiles(self, Url_List): - Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) - os.makedirs(Save_path, exist_ok=True) + async def download_files(self, url_list): + save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) + os.makedirs(save_path, exist_ok=True) model_file_path = None - for Item in Url_List.list: - url = Item.url - file_name = Item.name - file_path = os.path.join(Save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") - max_retries = 5 - for attempt in range(max_retries): - try: - with requests.get(url, stream=True) as r: - r.raise_for_status() - with open(file_path, "wb") as f: - shutil.copyfileobj(r.raw, f) - break - except Exception as e: - logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") - if attempt < max_retries - 1: - logging.info("Retrying...") - time.sleep(2) - else: - logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.") + async with aiohttp.ClientSession() as session: + for i in url_list.list: + url = i.url + file_name = i.name + file_path = os.path.join(save_path, file_name) + if file_path.endswith(".glb"): + model_file_path = file_path + logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") + max_retries = 5 + for attempt in range(max_retries): + try: + async with session.get(url) as resp: + resp.raise_for_status() + with open(file_path, "wb") as f: + async for chunk in resp.content.iter_chunked(32 * 1024): + f.write(chunk) + break + except Exception as e: + logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") + if attempt < max_retries - 1: + logging.info("Retrying...") + await asyncio.sleep(2) + else: + logging.info( + "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", + file_path, + max_retries, + ) return model_file_path @@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Detail(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Smooth(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI): m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Sketch(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI): material_type = "PBR" quality = "medium" mesh_mode = "Quad" - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + task_uuid, subscription_key = await self.create_generate_task( + images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs + ) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index af4b321f9..98024a9fa 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool: return image.shape[2] < 8000 and image.shape[1] < 8000 -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, TaskStatusResponse], estimated_duration: Optional[int] = None, node_id: Optional[str] = None, ) -> TaskStatusResponse: """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ TaskStatus.SUCCEEDED.value, @@ -115,7 +115,7 @@ def poll_until_finished( TaskStatus.FAILED.value, TaskStatus.CANCELLED.value, ], - status_extractor=lambda response: (response.status.value), + status_extractor=lambda response: response.status.value, auth_kwargs=auth_kwargs, result_url_extractor=get_video_url_from_task_status, estimated_duration=estimated_duration, @@ -167,11 +167,11 @@ class RunwayVideoGenNode(ComfyNodeABC): ) return True - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> RunwayImageToVideoResponse: """Poll the task status until it is finished then get the response.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -183,7 +183,7 @@ class RunwayVideoGenNode(ComfyNodeABC): node_id=node_id, ) - def generate_video( + async def generate_video( self, request: RunwayImageToVideoRequest, auth_kwargs: dict[str, str], @@ -200,15 +200,15 @@ class RunwayVideoGenNode(ComfyNodeABC): auth_kwargs=auth_kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() self.validate_task_created(initial_response) task_id = initial_response.id - final_response = self.get_response(task_id, auth_kwargs, node_id) + final_response = await self.get_response(task_id, auth_kwargs, node_id) self.validate_response(final_response) video_url = get_video_url_from_task_status(final_response) - return (download_url_to_video_output(video_url),) + return (await download_url_to_video_output(video_url),) class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): @@ -250,7 +250,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -265,7 +265,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): validate_input_image(start_frame) # Upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", @@ -274,7 +274,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): if len(download_urls) != 1: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -333,7 +333,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -348,7 +348,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): validate_input_image(start_frame) # Upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", @@ -357,7 +357,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): if len(download_urls) != 1: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> RunwayImageToVideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -437,7 +437,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -455,7 +455,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): # Upload images stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( stacked_input_images, max_images=2, mime_type="image/png", @@ -464,7 +464,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -543,11 +543,11 @@ class RunwayTextToImageNode(ComfyNodeABC): ) return True - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -559,7 +559,7 @@ class RunwayTextToImageNode(ComfyNodeABC): node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, ratio: str, @@ -574,7 +574,7 @@ class RunwayTextToImageNode(ComfyNodeABC): reference_images = None if reference_image is not None: validate_input_image(reference_image) - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( reference_image, max_images=1, mime_type="image/png", @@ -605,19 +605,19 @@ class RunwayTextToImageNode(ComfyNodeABC): auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() self.validate_task_created(initial_response) task_id = initial_response.id # Poll for completion - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) self.validate_response(final_response) # Download and return image image_url = get_image_url_from_task_status(final_response) - return (download_url_to_image_tensor(image_url),) + return (await download_url_to_image_tensor(image_url),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 02e421678..31309d831 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -124,7 +124,7 @@ class StabilityStableImageUltraNode: }, } - def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, + async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, **kwargs): validate_string(prompt, strip_whitespace=False) @@ -163,7 +163,7 @@ class StabilityStableImageUltraNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") @@ -257,7 +257,7 @@ class StabilityStableImageSD_3_5Node: }, } - def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, + async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, **kwargs): validate_string(prompt, strip_whitespace=False) @@ -302,7 +302,7 @@ class StabilityStableImageSD_3_5Node: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") @@ -374,7 +374,7 @@ class StabilityUpscaleConservativeNode: }, } - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, + async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, **kwargs): validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -403,7 +403,7 @@ class StabilityUpscaleConservativeNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") @@ -480,7 +480,7 @@ class StabilityUpscaleCreativeNode: }, } - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, + async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, **kwargs): validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -512,7 +512,7 @@ class StabilityUpscaleCreativeNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -527,7 +527,7 @@ class StabilityUpscaleCreativeNode: status_extractor=lambda x: get_async_dummy_status(x), auth_kwargs=kwargs, ) - response_poll: StabilityResultsGetResponse = operation.execute() + response_poll: StabilityResultsGetResponse = await operation.execute() if response_poll.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") @@ -563,8 +563,7 @@ class StabilityUpscaleFastNode: }, } - def api_call(self, image: torch.Tensor, - **kwargs): + async def api_call(self, image: torch.Tensor, **kwargs): image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() files = { @@ -583,7 +582,7 @@ class StabilityUpscaleFastNode: content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 65f3b21f5..d08cf9007 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import ( ) -def upload_image_to_tripo(image, **kwargs): - urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) +async def upload_image_to_tripo(image, **kwargs): + urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) def get_model_url_from_response(response: TripoTaskResponse) -> str: @@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: raise RuntimeError(f"Failed to get model url from response: {response}") -def poll_until_finished( +async def poll_until_finished( kwargs: dict[str, str], response: TripoTaskResponse, ) -> tuple[str, str]: @@ -57,7 +57,7 @@ def poll_until_finished( if response.code != 0: raise RuntimeError(f"Failed to generate mesh: {response.error}") task_id = response.data.task_id - response_poll = PollingOperation( + response_poll = await PollingOperation( poll_endpoint=ApiEndpoint( path=f"/proxy/tripo/v2/openapi/task/{task_id}", method=HttpMethod.GET, @@ -80,7 +80,7 @@ def poll_until_finished( ).execute() if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = download_url_to_bytesio(url) + bytesio = await download_url_to_bytesio(url) # Save the downloaded model file model_file = f"tripo_model_{task_id}.glb" with open(os.path.join(get_output_directory(), model_file), "wb") as f: @@ -88,6 +88,7 @@ def poll_until_finished( return model_file, task_id raise RuntimeError(f"Failed to generate mesh: {response_poll}") + class TripoTextToModelNode: """ Generates 3D models synchronously based on a text prompt using Tripo's API. @@ -126,11 +127,11 @@ class TripoTextToModelNode: API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): style_enum = None if style == "None" else style if not prompt: raise RuntimeError("Prompt is required") - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -155,7 +156,8 @@ class TripoTextToModelNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoImageToModelNode: """ @@ -195,12 +197,12 @@ class TripoImageToModelNode: API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): style_enum = None if style == "None" else style if image is None: raise RuntimeError("Image is required") - tripo_file = upload_image_to_tripo(image, **kwargs) - response = SynchronousOperation( + tripo_file = await upload_image_to_tripo(image, **kwargs) + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -225,7 +227,8 @@ class TripoImageToModelNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoMultiviewToModelNode: """ @@ -267,7 +270,7 @@ class TripoMultiviewToModelNode: API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): if image is None: raise RuntimeError("front image for multiview is required") images = [] @@ -282,11 +285,11 @@ class TripoMultiviewToModelNode: for image_name in ["image", "image_left", "image_back", "image_right"]: image_ = image_dict[image_name] if image_ is not None: - tripo_file = upload_image_to_tripo(image_, **kwargs) + tripo_file = await upload_image_to_tripo(image_, **kwargs) images.append(tripo_file) else: images.append(TripoFileEmptyReference()) - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -309,7 +312,8 @@ class TripoMultiviewToModelNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoTextureNode: @classmethod @@ -340,8 +344,8 @@ class TripoTextureNode: OUTPUT_NODE = True AVERAGE_DURATION = 80 - def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -358,7 +362,7 @@ class TripoTextureNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) class TripoRefineNode: @@ -387,8 +391,8 @@ class TripoRefineNode: OUTPUT_NODE = True AVERAGE_DURATION = 240 - def generate_mesh(self, model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -400,7 +404,7 @@ class TripoRefineNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) class TripoRigNode: @@ -425,8 +429,8 @@ class TripoRigNode: OUTPUT_NODE = True AVERAGE_DURATION = 180 - def generate_mesh(self, original_model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, original_model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -440,7 +444,8 @@ class TripoRigNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoRetargetNode: @classmethod @@ -475,8 +480,8 @@ class TripoRetargetNode: OUTPUT_NODE = True AVERAGE_DURATION = 30 - def generate_mesh(self, animation, original_model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, animation, original_model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -491,7 +496,8 @@ class TripoRetargetNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoConversionNode: @classmethod @@ -529,10 +535,10 @@ class TripoConversionNode: OUTPUT_NODE = True AVERAGE_DURATION = 30 - def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): if not original_model_task_id: raise RuntimeError("original_model_task_id is required") - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -549,7 +555,8 @@ class TripoConversionNode: ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + NODE_CLASS_MAPPINGS = { "TripoTextToModelNode": TripoTextToModelNode, diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 97bfe20e6..e25dab2f5 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,7 +1,7 @@ import io import logging import base64 -import requests +import aiohttp import torch from typing import Optional @@ -152,7 +152,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" API_NODE = True - def generate_video( + async def generate_video( self, prompt, aspect_ratio="16:9", @@ -217,7 +217,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() operation_name = initial_response.name logging.info(f"Veo generation started with operation name: {operation_name}") @@ -256,7 +256,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): ) # Execute the polling operation - poll_response = poll_operation.execute() + poll_response = await poll_operation.execute() # Now check for errors in the final response # Check for error in poll response @@ -281,7 +281,6 @@ class VeoVideoGenerationNode(ComfyNodeABC): raise Exception(error_message) # Extract video data - video_data = None if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: video = poll_response.response.videos[0] @@ -291,9 +290,9 @@ class VeoVideoGenerationNode(ComfyNodeABC): video_data = base64.b64decode(video.bytesBase64Encoded) elif hasattr(video, 'gcsUri') and video.gcsUri: # Download from URL - video_url = video.gcsUri - video_response = requests.get(video_url) - video_data = video_response.content + async with aiohttp.ClientSession() as session: + async with session.get(video.gcsUri) as video_response: + video_data = await video_response.content.read() else: raise Exception("Video returned but no data or URL was provided") else: From 735bb4bdb186bd4f39b9c924c24b8b39a7ef8b0d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 8 Aug 2025 01:21:00 -0700 Subject: [PATCH 03/71] Users report gfx1201 is buggy on flux with pytorch attention. (#9244) --- comfy/model_management.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9e6149d60..dc5b4711d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -321,9 +321,9 @@ try: if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True - if torch_version_numeric >= (2, 8): - if any((a in arch) for a in ["gfx1201"]): - ENABLE_PYTORCH_ATTENTION = True +# if torch_version_numeric >= (2, 8): +# if any((a in arch) for a in ["gfx1201"]): +# ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True From 5828607ccfef82a82931d8b66f3fd1176e04588f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 9 Aug 2025 09:49:25 -0700 Subject: [PATCH 04/71] Not sure if AMD actually support fp16 acc but it doesn't crash. (#9258) --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index dc5b4711d..c08f759e5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -340,7 +340,7 @@ if ENABLE_PYTORCH_ATTENTION: PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other try: - if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast: + if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast: torch.backends.cuda.matmul.allow_fp16_accumulation = True PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance logging.info("Enabled fp16 accumulation.") From 0552de7c7d6bcdd515da115d6756fd30494c7ff4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 10 Aug 2025 02:03:47 -0700 Subject: [PATCH 05/71] Bump pytorch cuda and rocm versions in readme instructions. (#9273) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 119098f5c..e4cff01a9 100644 --- a/README.md +++ b/README.md @@ -203,7 +203,7 @@ Put your VAE in: models/vae ### AMD GPUs (Linux only) AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` This is the command to install the nightly with ROCm 6.4 which might have some performance improvements: @@ -237,7 +237,7 @@ Additional discussion and help can be found [here](https://github.com/comfyanony Nvidia users should install stable pytorch using this command: -```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128``` +```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129``` This is the command to install pytorch nightly instead which might have performance improvements. From 966f3a52061b5e300f36c6de0d07c47d6ad12f76 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 11 Aug 2025 02:53:01 -0700 Subject: [PATCH 06/71] Only show feature flags log when verbose. (#9281) --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 0553a0dd7..8f9c88ebf 100644 --- a/server.py +++ b/server.py @@ -235,7 +235,7 @@ class PromptServer(): sid, ) - logging.info( + logging.debug( f"Feature flags negotiated for client {sid}: {client_flags}" ) first_message = False From fa340add552497a264071fd7f6c407ff4aa10449 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 11 Aug 2025 23:48:17 +0300 Subject: [PATCH 07/71] remove creation of non-used asyncio_loop (#9284) --- execution.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/execution.py b/execution.py index 952f0cc5c..1dc35738b 100644 --- a/execution.py +++ b/execution.py @@ -646,8 +646,6 @@ class PromptExecutor: self.add_message("execution_error", mes, broadcast=False) def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - asyncio_loop = asyncio.new_event_loop() - asyncio.set_event_loop(asyncio_loop) asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): From 629b17383718e1f46dbba101ea83ec897fbe3082 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 12 Aug 2025 04:52:12 +0800 Subject: [PATCH 08/71] Update template & embedded docs (#9283) * Update template & embedded docs * Update embedded docs to 0.2.6 --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2f4692b03..2fb38ef27 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.52 -comfyui-embedded-docs==0.2.4 +comfyui-workflow-templates==0.1.53 +comfyui-embedded-docs==0.2.6 torch torchsde torchvision From 2208aa616d3ad193cd37ef57076d4f5243cecdd3 Mon Sep 17 00:00:00 2001 From: PsychoLogicAu Date: Tue, 12 Aug 2025 06:56:16 +1000 Subject: [PATCH 09/71] Support SimpleTuner lycoris lora for Qwen-Image (#9280) --- comfy/lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/lora.py b/comfy/lora.py index 6686b7229..00358884b 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -301,6 +301,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["{}".format(key_lora)] = k # Support transformer prefix format key_map["transformer.{}".format(key_lora)] = k + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format return key_map From f4231a80b1b904b45ade0def9b37320c4adfe71b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 12 Aug 2025 00:15:14 +0300 Subject: [PATCH 10/71] fix(Kling Image API Node): do not pass "image_type" when no image (#9271) * fix(Kling Image API Node): do not pass "image_type" when no image * fix(Kling Image API Node): raise client-side error when kling_v1 is used with reference image --- comfy_api_nodes/nodes_kling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9d9eb5628..9d483bb0e 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -1690,7 +1690,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase): ): self.validate_prompt(prompt, negative_prompt) - if image is not None: + if image is None: + image_type = None + elif model_name == KlingImageGenModelName.kling_v1: + raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.") + else: image = tensor_to_base64_string(image) initial_operation = SynchronousOperation( From 1e3ae1eed8b925430e3b114ea6b7d08ea698e305 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 13 Aug 2025 05:14:27 +0800 Subject: [PATCH 11/71] Update template to 0.1.58 (#9302) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2fb38ef27..82af5690b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.53 +comfyui-workflow-templates==0.1.58 comfyui-embedded-docs==0.2.6 torch torchsde From e1d4f36d8df7446ebb1a5f2bf9c708c38a159f22 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:13:04 -0700 Subject: [PATCH 12/71] Update test release package workflow with python 3.13 cu129. (#9306) --- .github/workflows/windows_release_package.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 3926a65f3..b51746285 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -7,19 +7,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "129" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "6" # push: # branches: # - master @@ -64,6 +64,8 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd From 560d38f34c5bd532f89f2178f01ee819cf145820 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 12 Aug 2025 20:26:33 -0700 Subject: [PATCH 13/71] Wan2.2 fun control support. (#9292) --- comfy/ldm/wan/model.py | 19 +++++++++++++ comfy/model_base.py | 10 ++++++- comfy/model_detection.py | 5 ++++ comfy_extras/nodes_wan.py | 58 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 86d0795e9..4e2d99566 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -391,6 +391,7 @@ class WanModel(torch.nn.Module): cross_attn_norm=True, eps=1e-6, flf_pos_embed_token_number=None, + in_dim_ref_conv=None, image_model=None, device=None, dtype=None, @@ -484,6 +485,11 @@ class WanModel(torch.nn.Module): else: self.img_emb = None + if in_dim_ref_conv is not None: + self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + else: + self.ref_conv = None + def forward_orig( self, x, @@ -526,6 +532,13 @@ class WanModel(torch.nn.Module): e = e.reshape(t.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + full_ref = None + if self.ref_conv is not None: + full_ref = kwargs.get("reference_latent", None) + if full_ref is not None: + full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2) + x = torch.concat((full_ref, x), dim=1) + # context context = self.text_embedding(context) @@ -552,6 +565,9 @@ class WanModel(torch.nn.Module): # head x = self.head(x, e) + if full_ref is not None: + x = x[:, full_ref.shape[1]:] + # unpatchify x = self.unpatchify(x, grid_sizes) return x @@ -570,6 +586,9 @@ class WanModel(torch.nn.Module): x = torch.cat([x, time_dim_concat], dim=2) t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0]) + if self.ref_conv is not None and "reference_latent" in kwargs: + t_len += 1 + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8a2d9cbe6..cde61df7c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1124,7 +1124,11 @@ class WAN21(BaseModel): mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) - return torch.cat((mask, image), dim=1) + concat_mask_index = kwargs.get("concat_mask_index", 0) + if concat_mask_index != 0: + return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1) + else: + return torch.cat((mask, image), dim=1) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1140,6 +1144,10 @@ class WAN21(BaseModel): if time_dim_concat is not None: out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat)) + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8b57ebd2f..8acc51e20 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -373,6 +373,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix)) if flf_weight is not None: dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1] + + ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix)) + if ref_conv_weight is not None: + dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 0067d054d..f80c83ba6 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -103,6 +103,63 @@ class WanFunControlToVideo: out_latent["samples"] = latent return (positive, negative, out_latent) +class Wan22FunControlToVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }, + "optional": {"ref_image": ("IMAGE", ), + "control_video": ("IMAGE", ), + # "start_image": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None): + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) + concat_latent = concat_latent.repeat(1, 2, 1, 1, 1) + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(start_image[:, :, :, :3]) + concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask[:, :, :start_image.shape[0] + 3] = 0.0 + + ref_latent = None + if ref_image is not None: + ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + ref_latent = vae.encode(ref_image[:, :, :, :3]) + + if control_video is not None: + control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + concat_latent_image = vae.encode(control_video[:, :, :, :3]) + concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16}) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + out_latent = {} + out_latent["samples"] = latent + return (positive, negative, out_latent) + class WanFirstLastFrameToVideo: @classmethod def INPUT_TYPES(s): @@ -733,6 +790,7 @@ NODE_CLASS_MAPPINGS = { "WanTrackToVideo": WanTrackToVideo, "WanImageToVideo": WanImageToVideo, "WanFunControlToVideo": WanFunControlToVideo, + "Wan22FunControlToVideo": Wan22FunControlToVideo, "WanFunInpaintToVideo": WanFunInpaintToVideo, "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, "WanVaceToVideo": WanVaceToVideo, From 898d88e10e45f38500ca6044280bab4ca2f2d273 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Aug 2025 20:34:58 -0700 Subject: [PATCH 14/71] Make torchaudio exception catching less specific (#9309) --- comfy_api/latest/_ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 61597038f..26a55615f 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -12,7 +12,7 @@ import torch try: import torchaudio TORCH_AUDIO_AVAILABLE = True -except ImportError: +except: TORCH_AUDIO_AVAILABLE = False from PIL import Image as PILImage from PIL.PngImagePlugin import PngInfo From 3294782d19c3af0c6166aafe0465fe6b59571d17 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 13 Aug 2025 14:50:50 +0800 Subject: [PATCH 15/71] Update template to 0.1.59 (#9313) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 82af5690b..bfb31a73f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.24.4 -comfyui-workflow-templates==0.1.58 +comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch torchsde From 5ca8e2fac3b6826261c5991b0663b69eff60b3a1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:01:12 -0700 Subject: [PATCH 16/71] Update release workflow to python3.13 pytorch cu129 (#9315) * Try to reduce size of portable even more. * Update stable release workflow to python 3.13 cu129 * Update dependencies workflow to python3.13 cu129 --- .github/workflows/stable-release.yml | 15 ++++++++++----- .../workflows/windows_release_dependencies.yml | 6 +++--- .github/workflows/windows_release_package.yml | 2 ++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 61105abe4..a5a1ed2d0 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -12,17 +12,17 @@ on: description: 'CUDA version' required: true type: string - default: "128" + default: "129" python_minor: description: 'Python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'Python patch version' required: true type: string - default: "10" + default: "6" jobs: @@ -66,8 +66,13 @@ jobs: curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py ./python.exe get-pip.py ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/* - sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth - cd .. + sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth + + rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib + + cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd cp taesd/*.safetensors ./ComfyUI_copy/models/vae_approx/ diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index dfdb96d50..7761cc1ed 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -17,19 +17,19 @@ on: description: 'cuda version' required: true type: string - default: "128" + default: "129" python_minor: description: 'python minor version' required: true type: string - default: "12" + default: "13" python_patch: description: 'python patch version' required: true type: string - default: "10" + default: "6" # push: # branches: # - master diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index b51746285..3334e6839 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -66,6 +66,8 @@ jobs: sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth rm ./Lib/site-packages/torch/lib/dnnl.lib #I don't think this is actually used and I need the space + rm ./Lib/site-packages/torch/lib/libprotoc.lib + rm ./Lib/site-packages/torch/lib/libprotobuf.lib cd .. git clone --depth 1 https://github.com/comfyanonymous/taesd From e400f26c8fc9867248394616a4b58ecc4c53fbfd Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:44:54 -0700 Subject: [PATCH 17/71] Downgrade frontend for release. (#9316) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bfb31a73f..56ed85e01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.24.4 +comfyui-frontend-package==1.23.4 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From d5c1954d5cd4a789bbf84d2b75a955a5a3f93de8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 13 Aug 2025 03:46:38 -0400 Subject: [PATCH 18/71] ComfyUI version 0.3.50 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 5e2d09c81..29ec07ca6 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.49" +__version__ = "0.3.50" diff --git a/pyproject.toml b/pyproject.toml index 3c530ba85..659b5730a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.49" +version = "0.3.50" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From 615eb52049df98cebdd67bc672b66dc059171d7c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 00:48:06 -0700 Subject: [PATCH 19/71] Put back frontend version. (#9317) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 56ed85e01..bfb31a73f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.23.4 +comfyui-frontend-package==1.24.4 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From afa0a45206832b0e64e38454b7841d1da7ca56e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 11:42:08 -0700 Subject: [PATCH 20/71] Reduce portable size again. (#9323) * compress more * test * not needed --- .github/workflows/stable-release.yml | 2 +- .github/workflows/windows_release_package.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index a5a1ed2d0..2bc8e5905 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -90,7 +90,7 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/ComfyUI_windows_portable_nvidia.7z cd ComfyUI_windows_portable diff --git a/.github/workflows/windows_release_package.yml b/.github/workflows/windows_release_package.yml index 3334e6839..46375698e 100644 --- a/.github/workflows/windows_release_package.yml +++ b/.github/workflows/windows_release_package.yml @@ -86,7 +86,7 @@ jobs: cd .. - "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=512m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable + "C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable mv ComfyUI_windows_portable.7z ComfyUI/new_ComfyUI_windows_portable_nvidia_cu${{ inputs.cu }}_or_cpu.7z cd ComfyUI_windows_portable From 3da5a07510794c37d437cbea1d94065bb0aa8ebc Mon Sep 17 00:00:00 2001 From: contentis Date: Wed, 13 Aug 2025 20:53:27 +0200 Subject: [PATCH 21/71] SDPA backend priority (#9299) --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy/ldm/modules/attention.py | 4 ++-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ops.py | 13 +++++++++++++ 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 5eb2c6548..bea6090a2 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module): class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = F.scaled_dot_product_attention(q, k, v) + out = ops.scaled_dot_product_attention(q, k, v) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 35d2270ee..19c3c7af1 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 5c0373b74..79160412f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -285,7 +285,7 @@ def pytorch_attention(q, k, v): ) try: - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/ops.py b/comfy/ops.py index 2cc9bbc27..8b7b662b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,9 +23,18 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib +from torch.nn.attention import SDPBackend, sdpa_kernel cast_to = comfy.model_management.cast_to #TODO: remove once no more references +SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, +] +if torch.cuda.is_available(): + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -249,6 +258,10 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") + @staticmethod + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): From 9df8792d4b894a8ea8034414ef63f70deee4b1af Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 13 Aug 2025 12:12:41 -0700 Subject: [PATCH 22/71] Make last PR not crash comfy on old pytorch. (#9324) --- comfy/ldm/hunyuan3d/vae.py | 2 +- comfy/ldm/modules/attention.py | 4 +-- comfy/ldm/modules/diffusionmodules/model.py | 2 +- comfy/ops.py | 36 +++++++++++++-------- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index bea6090a2..6e8cbf1d9 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module): class CrossAttentionProcessor: def __call__(self, attn, q, k, v): - out = ops.scaled_dot_product_attention(q, k, v) + out = comfy.ops.scaled_dot_product_attention(q, k, v) return out diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 19c3c7af1..043df28df 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -448,7 +448,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha mask = mask.unsqueeze(1) if SDP_BATCH_LIMIT >= b: - out = ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) @@ -461,7 +461,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha if mask.shape[0] > 1: m = mask[i : i + SDP_BATCH_LIMIT] - out[i : i + SDP_BATCH_LIMIT] = ops.scaled_dot_product_attention( + out[i : i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention( q[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT], diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 79160412f..1fd12b35a 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -285,7 +285,7 @@ def pytorch_attention(q, k, v): ) try: - out = ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") diff --git a/comfy/ops.py b/comfy/ops.py index 8b7b662b6..be312d714 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,18 +23,32 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib -from torch.nn.attention import SDPBackend, sdpa_kernel + + +def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + + +try: + if torch.cuda.is_available(): + from torch.nn.attention import SDPBackend, sdpa_kernel + + SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] + + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) + + @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) +except (ModuleNotFoundError, TypeError): + logging.warning("Could not set sdpa backend priority.") cast_to = comfy.model_management.cast_to #TODO: remove once no more references -SDPA_BACKEND_PRIORITY = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, -] -if torch.cuda.is_available(): - SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -258,10 +272,6 @@ class disable_weight_init: else: raise ValueError(f"unsupported dimensions: {dims}") - @staticmethod - @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) - def scaled_dot_product_attention(q, k, v, *args, **kwargs): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) class manual_cast(disable_weight_init): class Linear(disable_weight_init.Linear): From c991a5da658667cf29f2916bef096fa7b18afd47 Mon Sep 17 00:00:00 2001 From: Simon Lui <502929+simonlui@users.noreply.github.com> Date: Wed, 13 Aug 2025 16:13:35 -0700 Subject: [PATCH 23/71] Fix XPU iGPU regressions (#9322) * Change bf16 check and switch non-blocking to off default with option to force to regain speed on certain classes of iGPUs and refactor xpu check. * Turn non_blocking off by default for xpu. * Update README.md for Intel GPUs. --- README.md | 28 ++++++++++------------------ comfy/cli_args.py | 2 ++ comfy/model_management.py | 21 +++++++++++++-------- 3 files changed, 25 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index e4cff01a9..fa99a8cbe 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a ## Get Started #### [Desktop Application](https://www.comfy.org/download) -- The easiest way to get started. +- The easiest way to get started. - Available on Windows & macOS. #### [Windows Portable Package](#installing) @@ -211,27 +211,19 @@ This is the command to install the nightly with ROCm 6.4 which might have some p ### Intel GPUs (Windows and Linux) -(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) - -1. To install PyTorch nightly, use the following command: +(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html) + +1. To install PyTorch xpu, use the following command: + +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu``` + +This is the command to install the Pytorch xpu nightly which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu``` -2. Launch ComfyUI by running `python main.py` - - (Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance. -1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below: - -``` -conda install libuv -pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/ -``` - -For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information. - -Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476). +1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information. ### NVIDIA @@ -352,7 +344,7 @@ Generate a self-signed certificate (not appropriate for shared/production use) a Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app will now be accessible with `https://...` instead of `http://...`. -> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above. +> Note: Windows users can use [alexisrolland/docker-openssl](https://github.com/alexisrolland/docker-openssl) or one of the [3rd party binary distributions](https://wiki.openssl.org/index.php/Binaries) to run the command example above.

If you use a container, note that the volume mount `-v` can be a relative path so `... -v ".\:/openssl-certs" ...` would create the key & cert files in the current directory of your command prompt or powershell terminal. ## Support and dev channel diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 0d760d524..de3e85c08 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -132,6 +132,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") +parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") + parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") diff --git a/comfy/model_management.py b/comfy/model_management.py index c08f759e5..2a9f18068 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -78,7 +78,6 @@ try: torch_version = torch.version.__version__ temp = torch_version.split(".") torch_version_numeric = (int(temp[0]), int(temp[1])) - xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available() except: pass @@ -102,10 +101,14 @@ if args.directml is not None: try: import intel_extension_for_pytorch as ipex # noqa: F401 - _ = torch.xpu.device_count() - xpu_available = xpu_available or torch.xpu.is_available() except: - xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) + pass + +try: + _ = torch.xpu.device_count() + xpu_available = torch.xpu.is_available() +except: + xpu_available = False try: if torch.backends.mps.is_available(): @@ -946,10 +949,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None): return dtype def device_supports_non_blocking(device): + if args.force_non_blocking: + return True if is_device_mps(device): return False #pytorch bug? mps doesn't support non blocking - if is_intel_xpu(): - return True + if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes + return False if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews) return False if directml_enabled: @@ -1282,10 +1287,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma return False if is_intel_xpu(): - if torch_version_numeric < (2, 6): + if torch_version_numeric < (2, 3): return True else: - return torch.xpu.get_device_capability(device)['has_bfloat16_conversions'] + return torch.xpu.is_bf16_supported() if is_ascend_npu(): return True From e4f7ea105f4b3034593f316560d952b80453e344 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 13 Aug 2025 18:33:05 -0700 Subject: [PATCH 24/71] Added context window support to core sampling code (#9238) * Added initial support for basic context windows - in progress * Add prepare_sampling wrapper for context window to more accurately estimate latent memory requirements, fixed merging wrappers/callbacks dicts in prepare_model_patcher * Made context windows compatible with different dimensions; works for WAN, but results are bad * Fix comfy.patcher_extension.merge_nested_dicts calls in prepare_model_patcher in sampler_helpers.py * Considering adding some callbacks to context window code to allow extensions of behavior without the need to rewrite code * Made dim slicing cleaner * Add Wan Context WIndows node for testing * Made context schedule and fuse method functions be stored on the handler instead of needing to be registered in core code to be found * Moved some code around between node_context_windows.py and context_windows.py * Change manual context window nodes names/ids * Added callbacks to IndexListContexHandler * Adjusted default values for context_length and context_overlap, made schema.inputs definition for WAN Context Windows less annoying * Make get_resized_cond more robust for various dim sizes * Fix typo * Another small fix --- comfy/context_windows.py | 537 ++++++++++++++++++++++++++ comfy/sampler_helpers.py | 6 +- comfy/samplers.py | 11 +- comfy_extras/nodes_context_windows.py | 89 +++++ nodes.py | 1 + 5 files changed, 639 insertions(+), 5 deletions(-) create mode 100644 comfy/context_windows.py create mode 100644 comfy_extras/nodes_context_windows.py diff --git a/comfy/context_windows.py b/comfy/context_windows.py new file mode 100644 index 000000000..928b111df --- /dev/null +++ b/comfy/context_windows.py @@ -0,0 +1,537 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Callable +import torch +import numpy as np +import collections +from dataclasses import dataclass +from abc import ABC, abstractmethod +import logging +import comfy.model_management +import comfy.patcher_extension +if TYPE_CHECKING: + from comfy.model_base import BaseModel + from comfy.model_patcher import ModelPatcher + from comfy.controlnet import ControlBase + + +class ContextWindowABC(ABC): + def __init__(self): + ... + + @abstractmethod + def get_tensor(self, full: torch.Tensor) -> torch.Tensor: + """ + Get torch.Tensor applicable to current window. + """ + raise NotImplementedError("Not implemented.") + + @abstractmethod + def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor: + """ + Apply torch.Tensor of window to the full tensor, in place. Returns reference to updated full tensor, not a copy. + """ + raise NotImplementedError("Not implemented.") + +class ContextHandlerABC(ABC): + def __init__(self): + ... + + @abstractmethod + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + raise NotImplementedError("Not implemented.") + + @abstractmethod + def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list: + raise NotImplementedError("Not implemented.") + + @abstractmethod + def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + raise NotImplementedError("Not implemented.") + + + +class IndexListContextWindow(ContextWindowABC): + def __init__(self, index_list: list[int], dim: int=0): + self.index_list = index_list + self.context_length = len(index_list) + self.dim = dim + + def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor: + if dim is None: + dim = self.dim + if dim == 0 and full.shape[dim] == 1: + return full + idx = [slice(None)] * dim + [self.index_list] + return full[idx].to(device) + + def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: + if dim is None: + dim = self.dim + idx = [slice(None)] * dim + [self.index_list] + full[idx] += to_add + return full + + +class IndexListCallbacks: + EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" + COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results" + EXECUTE_START = "execute_start" + EXECUTE_CLEANUP = "execute_cleanup" + + def init_callbacks(self): + return {} + + +@dataclass +class ContextSchedule: + name: str + func: Callable + +@dataclass +class ContextFuseMethod: + name: str + func: Callable + +ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) +class IndexListContextHandler(ContextHandlerABC): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0): + self.context_schedule = context_schedule + self.fuse_method = fuse_method + self.context_length = context_length + self.context_overlap = context_overlap + self.context_stride = context_stride + self.closed_loop = closed_loop + self.dim = dim + self._step = 0 + + self.callbacks = {} + + def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: + # for now, assume first dim is batch - should have stored on BaseModel in actual implementation + if x_in.size(self.dim) > self.context_length: + logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.") + return True + return False + + def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase: + if control.previous_controlnet is not None: + self.prepare_control_objects(control.previous_controlnet, device) + return control + + def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list: + if cond_in is None: + return None + # reuse or resize cond items to match context requirements + resized_cond = [] + # cond object is a list containing a dict - outer list is irrelevant, so just loop through it + for actual_cond in cond_in: + resized_actual_cond = actual_cond.copy() + # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary + for key in actual_cond: + try: + cond_item = actual_cond[key] + if isinstance(cond_item, torch.Tensor): + # check that tensor is the expected length - x.size(0) + if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim): + # if so, it's subsetting time - tell controls the expected indeces so they can handle them + actual_cond_item = window.get_tensor(cond_item) + resized_actual_cond[key] = actual_cond_item.to(device) + else: + resized_actual_cond[key] = cond_item.to(device) + # look for control + elif key == "control": + resized_actual_cond[key] = self.prepare_control_objects(cond_item, device) + elif isinstance(cond_item, dict): + new_cond_item = cond_item.copy() + # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) + for cond_key, cond_value in new_cond_item.items(): + if isinstance(cond_value, torch.Tensor): + if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim): + new_cond_item[cond_key] = window.get_tensor(cond_value, device) + # if has cond that is a Tensor, check if needs to be subset + elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) + elif cond_key == "num_video_frames": # for SVD + new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) + new_cond_item[cond_key].cond = window.context_length + resized_actual_cond[key] = new_cond_item + else: + resized_actual_cond[key] = cond_item + finally: + del cond_item # just in case to prevent VRAM issues + resized_cond.append(resized_actual_cond) + return resized_cond + + def set_step(self, timestep: torch.Tensor, model_options: dict[str]): + indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0]) + self._step = int(indexes[0]) + + def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: + full_length = x_in.size(self.dim) # TODO: choose dim based on model + context_windows = self.context_schedule.func(full_length, self, model_options) + context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] + return context_windows + + def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self.set_step(timestep, model_options) + context_windows = self.get_context_windows(model, x_in, model_options) + enumerated_context_windows = list(enumerate(context_windows)) + + conds_final = [torch.zeros_like(x_in) for _ in conds] + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + else: + counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds] + biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds] + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + for enum_window in enumerated_context_windows: + results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options) + for result in results: + self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep, + conds_final, counts_final, biases_final) + try: + # finalize conds + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + # relative is already normalized, so return as is + del counts_final + return conds_final + else: + # normalize conds via division by context usage counts + for i in range(len(conds_final)): + conds_final[i] /= counts_final[i] + del counts_final + return conds_final + finally: + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options) + + def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]], + model_options, device=None, first_device=None): + results: list[ContextResults] = [] + for window_idx, window in enumerated_context_windows: + # allow processing to end between context window executions for faster Cancel + comfy.model_management.throw_exception_if_processing_interrupted() + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): + callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) + + # update exposed params + model_options["transformer_options"]["context_window"] = window + # get subsections of x, timestep, conds + sub_x = window.get_tensor(x_in, device) + sub_timestep = window.get_tensor(timestep, device, dim=0) + sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds] + + sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options) + if device is not None: + for i in range(len(sub_conds_out)): + sub_conds_out[i] = sub_conds_out[i].to(x_in.device) + results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) + return results + + + def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor, + conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]): + if self.fuse_method.name == ContextFuseMethods.RELATIVE: + for pos, idx in enumerate(window.index_list): + # bias is the influence of a specific index in relation to the whole context window + bias = 1 - abs(idx - (window.index_list[0] + window.index_list[-1]) / 2) / ((window.index_list[-1] - window.index_list[0] + 1e-2) / 2) + bias = max(1e-2, bias) + # take weighted average relative to total bias of current idx + for i in range(len(sub_conds_out)): + bias_total = biases_final[i][idx] + prev_weight = (bias_total / (bias_total + bias)) + new_weight = (bias / (bias_total + bias)) + # account for dims of tensors + idx_window = [slice(None)] * self.dim + [idx] + pos_window = [slice(None)] * self.dim + [pos] + # apply new values + conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight + biases_final[i][idx] = bias_total + bias + else: + # add conds and counts based on weights of fuse method + weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep) + weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device) + for i in range(len(sub_conds_out)): + window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor) + window.add_window(counts_final[i], weights_tensor) + + for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks): + callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final) + + +def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs): + # limit noise_shape length to context_length for more accurate vram use estimation + model_options = kwargs.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is not None: + noise_shape = list(noise_shape) + noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length) + return executor(model, noise_shape, *args, **kwargs) + + +def create_prepare_sampling_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, + "ContextWindows_prepare_sampling", + _prepare_sampling_wrapper + ) + + +def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: + total_dims = len(x_in.shape) + weights_tensor = torch.Tensor(weights).to(device=device) + for _ in range(dim): + weights_tensor = weights_tensor.unsqueeze(0) + for _ in range(total_dims - dim - 1): + weights_tensor = weights_tensor.unsqueeze(-1) + return weights_tensor + +def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]: + total_dims = len(x_in.shape) + shape = [] + for _ in range(dim): + shape.append(1) + shape.append(x_in.shape[dim]) + for _ in range(total_dims - dim - 1): + shape.append(1) + return shape + +class ContextSchedules: + UNIFORM_LOOPED = "looped_uniform" + UNIFORM_STANDARD = "standard_uniform" + STATIC_STANDARD = "standard_static" + BATCHED = "batched" + + +# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py +def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames < handler.context_length: + windows.append(list(range(num_frames))) + return windows + + context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1) + # obtain uniform windows as normal, looping and all + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(handler._step))) + for j in range( + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), + ): + windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) + + return windows + +def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + # unlike looped, uniform_straight does NOT allow windows that loop back to the beginning; + # instead, they get shifted to the corresponding end of the frames. + # in the case that a window (shifted or not) is identical to the previous one, it gets skipped. + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + + context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1) + # first, obtain uniform windows as normal, looping and all + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(handler._step))) + for j in range( + int(ordered_halving(handler._step) * context_step) + pad, + num_frames + pad + (-handler.context_overlap), + (handler.context_length * context_step - handler.context_overlap), + ): + windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)]) + + # now that windows are created, shift any windows that loop, and delete duplicate windows + delete_idxs = [] + win_i = 0 + while win_i < len(windows): + # if window is rolls over itself, need to shift it + is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames) + if is_roll: + roll_val = windows[win_i][roll_idx] # roll_val might not be 0 for windows of higher strides + shift_window_to_end(windows[win_i], num_frames=num_frames) + # check if next window (cyclical) is missing roll_val + if roll_val not in windows[(win_i+1) % len(windows)]: + # need to insert new window here - just insert window starting at roll_val + windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length))) + # delete window if it's not unique + for pre_i in range(0, win_i): + if windows[win_i] == windows[pre_i]: + delete_idxs.append(win_i) + break + win_i += 1 + + # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation + delete_idxs.reverse() + for i in delete_idxs: + windows.pop(i) + + return windows + + +def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + # always return the same set of windows + delta = handler.context_length - handler.context_overlap + for start_idx in range(0, num_frames, delta): + # if past the end of frames, move start_idx back to allow same context_length + ending = start_idx + handler.context_length + if ending >= num_frames: + final_delta = ending - num_frames + final_start_idx = start_idx - final_delta + windows.append(list(range(final_start_idx, final_start_idx + handler.context_length))) + break + windows.append(list(range(start_idx, start_idx + handler.context_length))) + return windows + + +def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]): + windows = [] + if num_frames <= handler.context_length: + windows.append(list(range(num_frames))) + return windows + # always return the same set of windows; + # no overlap, just cut up based on context_length; + # last window size will be different if num_frames % opts.context_length != 0 + for start_idx in range(0, num_frames, handler.context_length): + windows.append(list(range(start_idx, min(start_idx + handler.context_length, num_frames)))) + return windows + + +def create_windows_default(num_frames: int, handler: IndexListContextHandler): + return [list(range(num_frames))] + + +CONTEXT_MAPPING = { + ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped, + ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard, + ContextSchedules.STATIC_STANDARD: create_windows_static_standard, + ContextSchedules.BATCHED: create_windows_batched, +} + + +def get_matching_context_schedule(context_schedule: str) -> ContextSchedule: + func = CONTEXT_MAPPING.get(context_schedule, None) + if func is None: + raise ValueError(f"Unknown context_schedule '{context_schedule}'.") + return ContextSchedule(context_schedule, func) + + +def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None): + return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs) + + +def create_weights_flat(length: int, **kwargs) -> list[float]: + # weight is the same for all + return [1.0] * length + +def create_weights_pyramid(length: int, **kwargs) -> list[float]: + # weight is based on the distance away from the edge of the context window; + # based on weighted average concept in FreeNoise paper + if length % 2 == 0: + max_weight = length // 2 + weight_sequence = list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)) + else: + max_weight = (length + 1) // 2 + weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)) + return weight_sequence + +def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs): + # based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302 + # only expected overlap is given different weights + weights_torch = torch.ones((length)) + # blend left-side on all except first window + if min(idxs) > 0: + ramp_up = torch.linspace(1e-37, 1, handler.context_overlap) + weights_torch[:handler.context_overlap] = ramp_up + # blend right-side on all except last window + if max(idxs) < full_length-1: + ramp_down = torch.linspace(1, 1e-37, handler.context_overlap) + weights_torch[-handler.context_overlap:] = ramp_down + return weights_torch + +class ContextFuseMethods: + FLAT = "flat" + PYRAMID = "pyramid" + RELATIVE = "relative" + OVERLAP_LINEAR = "overlap-linear" + + LIST = [PYRAMID, FLAT, OVERLAP_LINEAR] + LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR] + + +FUSE_MAPPING = { + ContextFuseMethods.FLAT: create_weights_flat, + ContextFuseMethods.PYRAMID: create_weights_pyramid, + ContextFuseMethods.RELATIVE: create_weights_pyramid, + ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear, +} + +def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod: + func = FUSE_MAPPING.get(fuse_method, None) + if func is None: + raise ValueError(f"Unknown fuse_method '{fuse_method}'.") + return ContextFuseMethod(fuse_method, func) + +# Returns fraction that has denominator that is a power of 2 +def ordered_halving(val): + # get binary value, padded with 0s for 64 bits + bin_str = f"{val:064b}" + # flip binary value, padding included + bin_flip = bin_str[::-1] + # convert binary to int + as_int = int(bin_flip, 2) + # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616, + # or b10000000000000000000000000000000000000000000000000000000000000000 (1 with 64 zero's) + return as_int / (1 << 64) + + +def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]: + all_indexes = list(range(num_frames)) + for w in windows: + for val in w: + try: + all_indexes.remove(val) + except ValueError: + pass + return all_indexes + + +def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]: + prev_val = -1 + for i, val in enumerate(window): + val = val % num_frames + if val < prev_val: + return True, i + prev_val = val + return False, -1 + + +def shift_window_to_start(window: list[int], num_frames: int): + start_val = window[0] + for i in range(len(window)): + # 1) subtract each element by start_val to move vals relative to the start of all frames + # 2) add num_frames and take modulus to get adjusted vals + window[i] = ((window[i] - start_val) + num_frames) % num_frames + + +def shift_window_to_end(window: list[int], num_frames: int): + # 1) shift window to start + shift_window_to_start(window, num_frames) + end_val = window[-1] + end_delta = num_frames - end_val - 1 + for i in range(len(window)): + # 2) add end_delta to each val to slide windows to end + window[i] = window[i] + end_delta diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 8dbc41455..e46971afb 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -149,7 +149,7 @@ def cleanup_models(conds, models): cleanup_additional_models(set(control_cleanup)) -def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): +def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict): ''' Registers hooks from conds. ''' @@ -158,8 +158,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict): for k in conds: get_hooks_from_cond(conds[k], hooks) # add wrappers and callbacks from ModelPatcher to transformer_options - model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers) - model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks) + comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False) + comfy.patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False) # begin registering hooks registered = comfy.hooks.HookGroup() target_dict = comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model) diff --git a/comfy/samplers.py b/comfy/samplers.py index ad2f40cdc..d5390d64e 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -16,6 +16,7 @@ import comfy.sampler_helpers import comfy.model_patcher import comfy.patcher_extension import comfy.hooks +import comfy.context_windows import scipy.stats import numpy @@ -198,14 +199,20 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H hooked_to_run.setdefault(p.hooks, list()) hooked_to_run[p.hooks] += [(p, i)] -def calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]): + handler: comfy.context_windows.ContextHandlerABC = model_options.get("context_handler", None) + if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options): + return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options) + return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options) + +def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _calc_cond_batch, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True) ) return executor.execute(model, conds, x_in, timestep, model_options) -def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): +def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): out_conds = [] out_counts = [] # separate conds by matching hooks diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py new file mode 100644 index 000000000..1c3d9e697 --- /dev/null +++ b/comfy_extras/nodes_context_windows.py @@ -0,0 +1,89 @@ +from __future__ import annotations +from comfy_api.latest import ComfyExtension, io +import comfy.context_windows +import nodes + + +class ContextWindowsManualNode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ContextWindowsManual", + display_name="Context Windows (Manual)", + category="context", + description="Manually set context windows.", + inputs=[ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + ], + outputs=[ + io.Model.Output(tooltip="The model with context windows applied during sampling."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model: + model = model.clone() + model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( + context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), + fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method), + context_length=context_length, + context_overlap=context_overlap, + context_stride=context_stride, + closed_loop=closed_loop, + dim=dim) + # make memory usage calculation only take into account the context window latents + comfy.context_windows.create_prepare_sampling_wrapper(model) + return io.NodeOutput(model) + +class WanContextWindowsManualNode(ContextWindowsManualNode): + @classmethod + def define_schema(cls) -> io.Schema: + schema = super().define_schema() + schema.node_id = "WanContextWindowsManual" + schema.display_name = "WAN Context Windows (Manual)" + schema.description = "Manually set context windows for WAN-like models (dim=2)." + schema.inputs = [ + io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."), + io.Combo.Input("context_schedule", options=[ + comfy.context_windows.ContextSchedules.STATIC_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, + comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, + comfy.context_windows.ContextSchedules.BATCHED, + ], tooltip="The stride of the context window."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), + io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + ] + return schema + + @classmethod + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model: + context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 + context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2) + + +class ContextWindowsExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ContextWindowsManualNode, + WanContextWindowsManualNode, + ] + +def comfy_entrypoint(): + return ContextWindowsExtension() diff --git a/nodes.py b/nodes.py index 9448f9c1b..860a236aa 100644 --- a/nodes.py +++ b/nodes.py @@ -2320,6 +2320,7 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", + "nodes_context_windows.py", ] import_failed = [] From 72fd4d22b6a4fa11a3f737c9a633e7d635a42181 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 13:03:21 -0700 Subject: [PATCH 25/71] av is an essential dependency. (#9341) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index bfb31a73f..551002b5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,11 +20,11 @@ tqdm psutil alembic SQLAlchemy +av>=14.2.0 #non essential dependencies: kornia>=0.7.1 spandrel soundfile -av>=14.2.0 pydantic~=2.0 pydantic-settings~=2.0 From 644b23ac0b92442b475e44397c62aa8de929d546 Mon Sep 17 00:00:00 2001 From: filtered <176114999+webfiltered@users.noreply.github.com> Date: Fri, 15 Aug 2025 07:36:53 +1000 Subject: [PATCH 26/71] Make custom node testing checkbox optional in issue templates (#9342) The checkbox for confirming custom node testing is now optional in both bug report and user support templates. This allows users to submit issues even if they haven't been able to test with custom nodes disabled, making the reporting process more accessible. --- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- .github/ISSUE_TEMPLATE/user-support.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 69ce998eb..3cf2717b7 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -22,7 +22,7 @@ body: description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. options: - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) - required: true + required: false - type: textarea attributes: label: Expected Behavior diff --git a/.github/ISSUE_TEMPLATE/user-support.yml b/.github/ISSUE_TEMPLATE/user-support.yml index 50657d493..281661f92 100644 --- a/.github/ISSUE_TEMPLATE/user-support.yml +++ b/.github/ISSUE_TEMPLATE/user-support.yml @@ -18,7 +18,7 @@ body: description: Please confirm you have tried to reproduce the issue with all custom nodes disabled. options: - label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help) - required: true + required: false - type: textarea attributes: label: Your question From fa570cbf599657e73c636872616c0b1f8e74f692 Mon Sep 17 00:00:00 2001 From: Yoland Yan <4950057+yoland68@users.noreply.github.com> Date: Thu, 14 Aug 2025 16:44:22 -0700 Subject: [PATCH 27/71] Update CODEOWNERS (#9343) --- CODEOWNERS | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index c4acbf06e..c8acd66d5 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -5,20 +5,21 @@ # Inlined the team members for now. # Maintainers -*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne -/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne +*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill +/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill # Python web server -/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne -/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne -/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne +/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill +/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill +/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill # Node developers -/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne -/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne +/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill +/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill +/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill From deebee4ff6fd1b2713683d22e5e2e07170daa867 Mon Sep 17 00:00:00 2001 From: guill Date: Thu, 14 Aug 2025 18:46:55 -0700 Subject: [PATCH 28/71] Update default parameters for Moonvalley video nodes (#9290) * Update default parameters for Moonvalley video nodes - Changed default negative prompts to a more extensive list for both BaseMoonvalleyVideoNode and MoonvalleyVideo2VideoNode. - Updated default guidance scale values for both nodes to enhance prompt adherence. - Set a fixed default seed value for consistency in video generation. * no message * ruff fix --------- Co-authored-by: thorsten --- comfy_api_nodes/nodes_moonvalley.py | 128 ++++++++++++++++++++-------- 1 file changed, 91 insertions(+), 37 deletions(-) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 164ca3ea5..806a70e06 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,6 +1,5 @@ import logging from typing import Any, Callable, Optional, TypeVar -import random import torch from comfy_api_nodes.util.validation_utils import ( get_image_dimensions, @@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: def _validate_video_dimensions(width: int, height: int) -> None: """Validates video dimensions meet Moonvalley V2V requirements.""" supported_resolutions = { - (1920, 1080), (1080, 1920), (1152, 1152), - (1536, 1152), (1152, 1536) + (1920, 1080), + (1080, 1920), + (1152, 1152), + (1536, 1152), + (1152, 1536), } if (width, height) not in supported_resolutions: - supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)]) - raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") + supported_list = ", ".join( + [f"{w}x{h}" for w, h in sorted(supported_resolutions)] + ) + raise ValueError( + f"Resolution {width}x{height} not supported. Supported: {supported_list}" + ) def _validate_container_format(video: VideoInput) -> None: """Validates video container format is MP4.""" container_format = video.get_container_format() - if container_format not in ['mp4', 'mov,mp4,m4a,3gp,3g2,mj2']: - raise ValueError(f"Only MP4 container format supported. Got: {container_format}") + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError( + f"Only MP4 container format supported. Got: {container_format}" + ) def _validate_and_trim_duration(video: VideoInput) -> VideoInput: @@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: return video - def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: """ Returns a new VideoInput object trimmed from the beginning to the specified duration, @@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: # Calculate target frame count that's divisible by 16 fps = input_container.streams.video[0].average_rate estimated_frames = int(duration_sec * fps) - target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + target_frames = ( + estimated_frames // 16 + ) * 16 # Round down to nearest multiple of 16 if target_frames == 0: raise ValueError("Video too short: need at least 16 frames for Moonvalley") @@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode: MoonvalleyTextToVideoInferenceParams, "negative_prompt", multiline=True, - default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts", + default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", ), "resolution": ( IO.COMBO, @@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode: "tooltip": "Resolution of the output video", }, ), - # "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}), "prompt_adherence": model_field_to_node_input( IO.FLOAT, MoonvalleyTextToVideoInferenceParams, "guidance_scale", - default=7.0, + default=10.0, step=1, min=1, max=20, @@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode: IO.INT, MoonvalleyTextToVideoInferenceParams, "seed", - default=random.randint(0, 2**32 - 1), + default=9, min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", - control_after_generate=True, ), "steps": model_field_to_node_input( IO.INT, @@ -532,9 +539,11 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode): # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - image_url = (await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type - ))[0] + image_url = ( + await upload_images_to_comfyapi( + image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type + ) + )[0] request = MoonvalleyTextToVideoRequest( image_url=image_url, prompt_text=prompt, inference_params=inference_params @@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): return { "required": { "prompt": model_field_to_node_input( - IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text", - multiline=True + IO.STRING, + MoonvalleyVideoToVideoRequest, + "prompt_text", + multiline=True, ), "negative_prompt": model_field_to_node_input( IO.STRING, MoonvalleyVideoToVideoInferenceParams, "negative_prompt", multiline=True, - default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts" + default=" gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring", + ), + "seed": model_field_to_node_input( + IO.INT, + MoonvalleyVideoToVideoInferenceParams, + "seed", + default=9, + min=0, + max=4294967295, + step=1, + display="number", + tooltip="Random seed value", + control_after_generate=False, + ), + "prompt_adherence": model_field_to_node_input( + IO.FLOAT, + MoonvalleyVideoToVideoInferenceParams, + "guidance_scale", + default=10.0, + step=1, + min=1, + max=20, ), - "seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True), }, "hidden": { "auth_token": "AUTH_TOKEN_COMFY_ORG", @@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): "unique_id": "UNIQUE_ID", }, "optional": { - "video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}), + "video": ( + IO.VIDEO, + { + "default": "", + "multiline": False, + "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.", + }, + ), "control_type": ( ["Motion Transfer", "Pose Transfer"], {"default": "Motion Transfer"}, @@ -602,8 +640,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): "max": 100, "tooltip": "Only used if control_type is 'Motion Transfer'", }, - ) - } + ), + "image": model_field_to_node_input( + IO.IMAGE, + MoonvalleyTextToVideoRequest, + "image_url", + tooltip="The reference image used to generate the video", + ), + }, } RETURN_TYPES = ("VIDEO",) @@ -613,6 +657,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") + image = kwargs.get("image", None) if not video: raise MoonvalleyApiError("video is required") @@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): video_url = "" if video: validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi( + validated_video, auth_kwargs=kwargs + ) + mime_type = "image/png" + if not image is None: + validate_input_image(image, with_frame_conditioning=True) + image_url = await upload_images_to_comfyapi( + image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type + ) control_type = kwargs.get("control_type") motion_intensity = kwargs.get("motion_intensity") @@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): # Only include motion_intensity for Motion Transfer control_params = {} if control_type == "Motion Transfer" and motion_intensity is not None: - control_params['motion_intensity'] = motion_intensity + control_params["motion_intensity"] = motion_intensity - inference_params=MoonvalleyVideoToVideoInferenceParams( + inference_params = MoonvalleyVideoToVideoInferenceParams( negative_prompt=negative_prompt, seed=kwargs.get("seed"), - control_params=control_params + control_params=control_params, ) control = self.parseControlParameter(control_type) @@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode): prompt_text=prompt, inference_params=inference_params, ) + request.image_url = image_url if not image is None else None initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -694,15 +748,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = self.parseWidthHeightFromRes(kwargs.get("resolution")) - inference_params=MoonvalleyTextToVideoInferenceParams( - negative_prompt=negative_prompt, - steps=kwargs.get("steps"), - seed=kwargs.get("seed"), - guidance_scale=kwargs.get("prompt_adherence"), - num_frames=128, - width=width_height.get("width"), - height=width_height.get("height"), - ) + inference_params = MoonvalleyTextToVideoInferenceParams( + negative_prompt=negative_prompt, + steps=kwargs.get("steps"), + seed=kwargs.get("seed"), + guidance_scale=kwargs.get("prompt_adherence"), + num_frames=128, + width=width_height.get("width"), + height=width_height.get("height"), + ) request = MoonvalleyTextToVideoRequest( prompt_text=prompt, inference_params=inference_params ) From 5d65d6753b195d674ce16522d6c34f9a33f36269 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 15 Aug 2025 04:48:41 +0300 Subject: [PATCH 29/71] convert WAN nodes to V3 schema (#9201) --- comfy_extras/nodes_wan.py | 549 +++++++++++++++++++++----------------- 1 file changed, 298 insertions(+), 251 deletions(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index f80c83ba6..694a183f6 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -9,29 +9,35 @@ import comfy.clip_vision import json import numpy as np from typing import Tuple +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class WanImageToVideo: +class WanImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -51,32 +57,36 @@ class WanImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFunControlToVideo: +class WanFunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "control_video": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunControlToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -101,31 +111,34 @@ class WanFunControlToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class Wan22FunControlToVideo: +class Wan22FunControlToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"ref_image": ("IMAGE", ), - "control_video": ("IMAGE", ), - # "start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="Wan22FunControlToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("ref_image", optional=True), + io.Image.Input("control_video", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -158,32 +171,36 @@ class Wan22FunControlToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFirstLastFrameToVideo: +class WanFirstLastFrameToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ), - "clip_vision_end_image": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFirstLastFrameToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_end_image", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -224,62 +241,70 @@ class WanFirstLastFrameToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanFunInpaintToVideo: +class WanFunInpaintToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "end_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanFunInpaintToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.Image.Input("end_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput: flfv = WanFirstLastFrameToVideo() - return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) + return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output) -class WanVaceToVideo: +class WanVaceToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}), - }, - "optional": {"control_video": ("IMAGE", ), - "control_masks": ("MASK", ), - "reference_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanVaceToVideo", + category="conditioning/video_models", + is_experimental=True, + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01), + io.Image.Input("control_video", optional=True), + io.Mask.Input("control_masks", optional=True), + io.Image.Input("reference_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT") - RETURN_NAMES = ("positive", "negative", "latent", "trim_latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - EXPERIMENTAL = True - - def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput: latent_length = ((length - 1) // 4) + 1 if control_video is not None: control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) @@ -336,52 +361,59 @@ class WanVaceToVideo: latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device()) out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent, trim_latent) + return io.NodeOutput(positive, negative, out_latent, trim_latent) -class TrimVideoLatent: +class TrimVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}), - }} + def define_schema(cls): + return io.Schema( + node_id="TrimVideoLatent", + category="latent/video", + is_experimental=True, + inputs=[ + io.Latent.Input("samples"), + io.Int.Input("trim_amount", default=0, min=0, max=99999), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/video" - - EXPERIMENTAL = True - - def op(self, samples, trim_amount): + @classmethod + def execute(cls, samples, trim_amount) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = s1[:, :, trim_amount:] - return (samples_out,) + return io.NodeOutput(samples_out) -class WanCameraImageToVideo: +class WanCameraImageToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ), - "start_image": ("IMAGE", ), - "camera_conditions": ("WAN_CAMERA_EMBEDDING", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanCameraImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.WanCameraEmbedding.Input("camera_conditions", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent) @@ -404,29 +436,34 @@ class WanCameraImageToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class WanPhantomSubjectToVideo: +class WanPhantomSubjectToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"images": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanPhantomSubjectToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("images", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative_text"), + io.Conditioning.Output(display_name="negative_img_text"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, width, height, length, batch_size, images): + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) cond2 = negative if images is not None: @@ -442,7 +479,7 @@ class WanPhantomSubjectToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, cond2, negative, out_latent) + return io.NodeOutput(positive, cond2, negative, out_latent) def parse_json_tracks(tracks): """Parse JSON track data into a standardized format""" @@ -655,39 +692,41 @@ def patch_motion( return out_mask_full, out_feature_full -class WanTrackToVideo: +class WanTrackToVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "vae": ("VAE", ), - "tracks": ("STRING", {"multiline": True, "default": "[]"}), - "width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - "temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}), - "topk": ("INT", {"default": 2, "min": 1, "max": 10}), - "start_image": ("IMAGE", ), - }, - "optional": { - "clip_vision_output": ("CLIP_VISION_OUTPUT", ), - }} + def define_schema(cls): + return io.Schema( + node_id="WanPhantomSubjectToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.String.Input("tracks", multiline=True, default="[]"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input("topk", default=2, min=1, max=10), + io.Image.Input("start_image"), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") - RETURN_NAMES = ("positive", "negative", "latent") - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, positive, negative, vae, tracks, width, height, length, batch_size, - temperature, topk, start_image=None, clip_vision_output=None): + @classmethod + def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size, + temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput: tracks_data = parse_json_tracks(tracks) if not tracks_data: - return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) + return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) @@ -741,34 +780,36 @@ class WanTrackToVideo: out_latent = {} out_latent["samples"] = latent - return (positive, negative, out_latent) + return io.NodeOutput(positive, negative, out_latent) -class Wan22ImageToVideoLatent: +class Wan22ImageToVideoLatent(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"vae": ("VAE", ), - "width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), - "height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}), - "length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), - }, - "optional": {"start_image": ("IMAGE", ), - }} + def define_schema(cls): + return io.Schema( + node_id="Wan22ImageToVideoLatent", + category="conditioning/inpaint", + inputs=[ + io.Vae.Input("vae"), + io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Latent.Output(), + ], + ) - - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "conditioning/inpaint" - - def encode(self, vae, width, height, length, batch_size, start_image=None): + @classmethod + def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) if start_image is None: out_latent = {} out_latent["samples"] = latent - return (out_latent,) + return io.NodeOutput(out_latent) mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device()) @@ -783,19 +824,25 @@ class Wan22ImageToVideoLatent: latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask) out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1)) out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1)) - return (out_latent,) + return io.NodeOutput(out_latent) -NODE_CLASS_MAPPINGS = { - "WanTrackToVideo": WanTrackToVideo, - "WanImageToVideo": WanImageToVideo, - "WanFunControlToVideo": WanFunControlToVideo, - "Wan22FunControlToVideo": Wan22FunControlToVideo, - "WanFunInpaintToVideo": WanFunInpaintToVideo, - "WanFirstLastFrameToVideo": WanFirstLastFrameToVideo, - "WanVaceToVideo": WanVaceToVideo, - "TrimVideoLatent": TrimVideoLatent, - "WanCameraImageToVideo": WanCameraImageToVideo, - "WanPhantomSubjectToVideo": WanPhantomSubjectToVideo, - "Wan22ImageToVideoLatent": Wan22ImageToVideoLatent, -} +class WanExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanTrackToVideo, + WanImageToVideo, + WanFunControlToVideo, + Wan22FunControlToVideo, + WanFunInpaintToVideo, + WanFirstLastFrameToVideo, + WanVaceToVideo, + TrimVideoLatent, + WanCameraImageToVideo, + WanPhantomSubjectToVideo, + Wan22ImageToVideoLatent, + ] + +async def comfy_entrypoint() -> WanExtension: + return WanExtension() From ad19a069f68a19566632b9bda3e72f4eed8a22d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:16:01 -0700 Subject: [PATCH 30/71] Make SLG nodes work on Qwen Image model. (#9345) --- comfy/ldm/qwen_image/model.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index c15ab8e40..99843f88d 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -356,6 +356,7 @@ class QwenImageTransformer2DModel(nn.Module): context, attention_mask=None, guidance: torch.Tensor = None, + transformer_options={}, **kwargs ): timestep = timesteps @@ -383,14 +384,26 @@ class QwenImageTransformer2DModel(nn.Module): else self.time_text_embed(timestep, guidance, hidden_states) ) - for block in self.transformer_blocks: - encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - encoder_hidden_states_mask=encoder_hidden_states_mask, - temb=temb, - image_rotary_emb=image_rotary_emb, - ) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"]) + return out + out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) From f0d5d0111f1f78bc8ce5d1f3968f19e40cd2ce7b Mon Sep 17 00:00:00 2001 From: "Xiangxi Guo (Ryan)" Date: Thu, 14 Aug 2025 20:41:37 -0700 Subject: [PATCH 31/71] Avoid torch compile graphbreak for older pytorch versions (#9344) Turns out torch.compile has some gaps in context manager decorator syntax support. I've sent patches to fix that in PyTorch, but it won't be available for all the folks running older versions of PyTorch, hence this trivial patch. --- comfy/ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index be312d714..2be35f76a 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -41,9 +41,11 @@ try: SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - @sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True) def scaled_dot_product_attention(q, k, v, *args, **kwargs): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + # Use this (rather than the decorator syntax) to eliminate graph + # break for pytorch < 2.9 + with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") From 4e5c230f6a957962961794c07f02be748076c771 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 20:44:02 -0700 Subject: [PATCH 32/71] Fix last commit not working on older pytorch. (#9346) --- comfy/ops.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 2be35f76a..18e7db705 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -32,20 +32,21 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs): try: if torch.cuda.is_available(): from torch.nn.attention import SDPBackend, sdpa_kernel + import inspect + if "set_priority" in inspect.signature(sdpa_kernel).parameters: + SDPA_BACKEND_PRIORITY = [ + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.MATH, + ] - SDPA_BACKEND_PRIORITY = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.MATH, - ] + SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) - - def scaled_dot_product_attention(q, k, v, *args, **kwargs): - # Use this (rather than the decorator syntax) to eliminate graph - # break for pytorch < 2.9 - with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): - return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + def scaled_dot_product_attention(q, k, v, *args, **kwargs): + with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) + else: + logging.warning("Torch version too old to set sdpa backend priority.") except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") From e08ecfbd8a9deda8939b14d7f1ff7d7139f1a4ed Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 14 Aug 2025 21:22:26 -0700 Subject: [PATCH 33/71] Add warning when using old pytorch. (#9347) --- comfy/rmsnorm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 66ae8321d..555542a46 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -1,6 +1,7 @@ import torch import comfy.model_management import numbers +import logging RMSNorm = None @@ -9,6 +10,7 @@ try: RMSNorm = torch.nn.RMSNorm except: rms_norm_torch = None + logging.warning("Please update pytorch to use native RMSNorm") def rms_norm(x, weight=None, eps=1e-6): From 027c63f63a7f5f380a4df1057c548410b0a87606 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 15 Aug 2025 21:57:47 +0300 Subject: [PATCH 34/71] fix(OpenAIGPTImage1): set correct MIME type for multipart uploads to OpenAI edits (#9348) --- comfy_api_nodes/nodes_openai.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index ab3c5363b..cbff2b2d2 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC): path = "/proxy/openai/images/generations" content_type = "application/json" request_class = OpenAIImageGenerationRequest - img_binaries = [] - mask_binary = None files = [] if image is not None: @@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC): img_byte_arr = io.BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) - img_binary = img_byte_arr - img_binary.name = f"image_{i}.png" - img_binaries.append(img_binary) if batch_size == 1: - files.append(("image", img_binary)) + files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) else: - files.append(("image[]", img_binary)) + files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) if mask is not None: if image is None: @@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC): mask_img_byte_arr = io.BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) - mask_binary = mask_img_byte_arr - mask_binary.name = "mask.png" - files.append(("mask", mask_binary)) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) # Build the operation operation = SynchronousOperation( From c308a8840aebf06649364e8e175862250a2d8823 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 15 Aug 2025 12:50:39 -0700 Subject: [PATCH 35/71] Add FluxKontextMultiReferenceLatentMethod node. (#9356) This node is only useful if someone trains the kontext model to properly use multiple reference images via the index method. The default is the offset method which feeds the multiple images like if they were stitched together as one. This method works with the current flux kontext model. --- comfy/ldm/flux/model.py | 24 ++++++++++++++++-------- comfy/model_base.py | 4 ++++ comfy_extras/nodes_flux.py | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8f4d99f54..c4de82795 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -224,19 +224,27 @@ class Flux(nn.Module): if ref_latents is not None: h = 0 w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "offset") == "index" for ref in ref_latents: - h_offset = 0 - w_offset = 0 - if ref.shape[-2] + h > ref.shape[-1] + w: - w_offset = w + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 else: - h_offset = h + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - h = max(h, ref.shape[-2] + h_offset) - w = max(w, ref.shape[-1] + w_offset) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) diff --git a/comfy/model_base.py b/comfy/model_base.py index cde61df7c..bf874b875 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -890,6 +890,10 @@ class Flux(BaseModel): for lat in ref_latents: latents.append(self.process_latent_in(lat)) out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out def extra_conds_shapes(self, **kwargs): diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 8a8a17698..c8db75bb3 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -100,9 +100,28 @@ class FluxKontextImageScale: return (image, ) +class FluxKontextMultiReferenceLatentMethod: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning": ("CONDITIONING", ), + "reference_latents_method": (("offset", "index"), ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + EXPERIMENTAL = True + + CATEGORY = "advanced/conditioning/flux" + + def append(self, conditioning, reference_latents_method): + c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) + return (c, ) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeFlux": CLIPTextEncodeFlux, "FluxGuidance": FluxGuidance, "FluxDisableGuidance": FluxDisableGuidance, "FluxKontextImageScale": FluxKontextImageScale, + "FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod, } From 1702e6df16b0a52e147f19e3d5c5548c25a64339 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 15 Aug 2025 14:29:58 -0700 Subject: [PATCH 36/71] Implement wan2.2 camera model. (#9357) Use the old WanCameraImageToVideo node. --- comfy/ldm/wan/model.py | 7 ++++++- comfy/model_detection.py | 5 ++++- comfy/supported_models.py | 14 +++++++++++++- comfy_extras/nodes_wan.py | 7 +++++-- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4e2d99566..9d3741be3 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -768,7 +768,12 @@ class CameraWanModel(WanModel): operations=None, ): - super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) + if model_type == 'camera': + model_type = 'i2v' + else: + model_type = 't2v' + + super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) operation_settings = {"operations": operations, "device": device, "dtype": dtype} self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8acc51e20..2bec0541e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: - dit_config["model_type"] = "camera" + if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "camera" + else: + dit_config["model_type"] = "camera_2.2" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 156ff9e26..7ed6dfd69 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V): def get_model(self, state_dict, prefix="", device=None): out = model_base.WAN21_Camera(self, image_to_video=False, device=device) return out + +class WAN22_Camera(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "camera_2.2", + "in_dim": 36, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_Camera(self, image_to_video=False, device=device) + return out + class WAN21_Vace(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 694a183f6..83a990688 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -422,9 +422,12 @@ class WanCameraImageToVideo(io.ComfyNode): start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] + mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1])) + mask[:, :, :start_image.shape[0] + 3] = 0.0 + mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2) - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask}) if camera_conditions is not None: positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) From ed2e33c69a291094c4fcc13d8426c49844a6363c Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Fri, 15 Aug 2025 20:32:58 -0700 Subject: [PATCH 37/71] bump frontend version to 1.25.8 (#9361) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 551002b5b..2ae44ebe1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.24.4 +comfyui-frontend-package==1.25.8 comfyui-workflow-templates==0.1.59 comfyui-embedded-docs==0.2.6 torch From 20a84166d0d37dd6833caa6cadf3bfac8c241b48 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Sat, 16 Aug 2025 02:07:12 -0400 Subject: [PATCH 38/71] record audio node (#8716) * record audio node * sf --- comfy_extras/nodes_audio.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index a90b31779..3b23f65d8 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -346,6 +346,24 @@ class LoadAudio: return "Invalid audio file: {}".format(audio) return True +class RecordAudio: + @classmethod + def INPUT_TYPES(s): + return {"required": {"audio": ("AUDIO_RECORD", {})}} + + CATEGORY = "audio" + + RETURN_TYPES = ("AUDIO", ) + FUNCTION = "load" + + def load(self, audio): + audio_path = folder_paths.get_annotated_filepath(audio) + + waveform, sample_rate = torchaudio.load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return (audio, ) + + NODE_CLASS_MAPPINGS = { "EmptyLatentAudio": EmptyLatentAudio, "VAEEncodeAudio": VAEEncodeAudio, @@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = { "LoadAudio": LoadAudio, "PreviewAudio": PreviewAudio, "ConditioningStableAudio": ConditioningStableAudio, + "RecordAudio": RecordAudio, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SaveAudio": "Save Audio (FLAC)", "SaveAudioMP3": "Save Audio (MP3)", "SaveAudioOpus": "Save Audio (Opus)", + "RecordAudio": "Record Audio", } From 0f2b8525bcafe213e8421a49564a90f926e81f2e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 16 Aug 2025 14:51:28 -0700 Subject: [PATCH 39/71] Qwen image model refactor. (#9375) --- comfy/ldm/qwen_image/model.py | 36 +++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 99843f88d..40d8fd979 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module): self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) self.gradient_checkpointing = False - def pos_embeds(self, x, context): + def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape patch_size = self.patch_size + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + h_offset = ((h_offset + (patch_size // 2)) // patch_size) + w_offset = ((w_offset + (patch_size // 2)) // patch_size) - txt_start = round(max(h_len, w_len)) - txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 1] + index + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape def forward( self, @@ -363,13 +367,13 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = context encoder_hidden_states_mask = attention_mask - image_rotary_emb = self.pos_embeds(x, context) + hidden_states, img_ids, orig_shape = self.process_img(x) + num_embeds = hidden_states.shape[1] - hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) - orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) + txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -408,6 +412,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] From ed43784b0d04e5b8e8ff2c057fa84b9c5132aaf2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 17 Aug 2025 13:45:39 -0700 Subject: [PATCH 40/71] WIP Qwen edit model: The diffusion model part. (#9383) --- comfy/ldm/qwen_image/model.py | 26 ++++++++++++++++++++++++++ comfy/model_base.py | 10 ++++++++++ 2 files changed, 36 insertions(+) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 40d8fd979..a3c726299 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -360,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module): context, attention_mask=None, guidance: torch.Tensor = None, + ref_latents=None, transformer_options={}, **kwargs ): @@ -370,6 +371,31 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] + if ref_latents is not None: + h = 0 + w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "index") == "index" + for ref in ref_latents: + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 + else: + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) + + kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + hidden_states = torch.cat([hidden_states, kontext], dim=1) + img_ids = torch.cat([img_ids, kontext_ids], dim=1) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) diff --git a/comfy/model_base.py b/comfy/model_base.py index bf874b875..15bd7abef 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1331,4 +1331,14 @@ class QwenImage(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out From d4e353a94ec5a8cb15ed151990a9518b890e5d4f Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Mon, 18 Aug 2025 05:38:40 +0800 Subject: [PATCH 41/71] Update template to 0.1.60 (#9377) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2ae44ebe1..72a700028 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.8 -comfyui-workflow-templates==0.1.59 +comfyui-workflow-templates==0.1.60 comfyui-embedded-docs==0.2.6 torch torchsde From 7f3b9b16c6636cb1201213574892d33c2a35e4ba Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 17 Aug 2025 15:54:07 -0700 Subject: [PATCH 42/71] Make step index detection much more robust (#9392) --- comfy/context_windows.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 928b111df..041f380f9 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0]) - self._step = int(indexes[0]) + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) + matches = torch.nonzero(mask) + if torch.numel(matches) == 0: + raise Exception("No sample_sigmas matched current timestep; something went wrong.") + self._step = int(matches[0].item()) def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model From da2efeaec6609265051165bfb413a2a4c84cf4bb Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sun, 17 Aug 2025 20:21:02 -0700 Subject: [PATCH 43/71] Bump frontend to 1.25.9 (#9394) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 72a700028..c7a5c47ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.25.8 +comfyui-frontend-package==1.25.9 comfyui-workflow-templates==0.1.60 comfyui-embedded-docs==0.2.6 torch From bd2ab73976a4e245db3e057795328c89bfd98a88 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 18 Aug 2025 10:26:55 +0300 Subject: [PATCH 44/71] fix(WAN-nodes): invalid nodeid for WanTrackToVideo (#9396) --- comfy_extras/nodes_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 83a990688..0fff02f76 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -699,7 +699,7 @@ class WanTrackToVideo(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="WanPhantomSubjectToVideo", + node_id="WanTrackToVideo", category="conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), From 4977f203fa8e9e3ab22884c8ace8f9b540d48952 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 18 Aug 2025 19:38:34 -0700 Subject: [PATCH 45/71] P2 of qwen edit model. (#9412) * P2 of qwen edit model. * Typo. * Fix normal qwen. * Fix. * Make the TextEncodeQwenImageEdit also set the ref latent. If you don't want it to set the ref latent and want to use the ReferenceLatent node with your custom latent instead just disconnect the VAE. --- comfy/clip_model.py | 2 +- comfy/model_base.py | 8 + comfy/sd1_clip.py | 11 +- comfy/text_encoders/bert.py | 2 +- comfy/text_encoders/llama.py | 43 ++- comfy/text_encoders/qwen_image.py | 20 +- comfy/text_encoders/qwen_vl.py | 428 ++++++++++++++++++++++++++++++ comfy/text_encoders/t5.py | 2 +- comfy_extras/nodes_qwen.py | 63 +++++ nodes.py | 1 + 10 files changed, 565 insertions(+), 15 deletions(-) create mode 100644 comfy/text_encoders/qwen_vl.py create mode 100644 comfy_extras/nodes_qwen.py diff --git a/comfy/clip_model.py b/comfy/clip_model.py index c8294d483..7e47d8a55 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module): self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) - def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): + def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]): if embeds is not None: x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) else: diff --git a/comfy/model_base.py b/comfy/model_base.py index 15bd7abef..6c861b15e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1325,6 +1325,7 @@ class Omnigen2(BaseModel): class QwenImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1342,3 +1343,10 @@ class QwenImage(BaseModel): if ref_latents_method is not None: out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + return out diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ade340fd1..1e8adbe69 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) index = 0 pad_extra = 0 + embeds_info = [] for o in other_embeds: emb = o[1] if torch.is_tensor(emb): emb = {"type": "embedding", "data": emb} + extra = None emb_type = emb.get("type", None) if emb_type == "embedding": emb = emb.get("data", None) else: if hasattr(self.transformer, "preprocess_embed"): - emb = self.transformer.preprocess_embed(emb, device=device) + emb, extra = self.transformer.preprocess_embed(emb, device=device) else: emb = None @@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] index += emb_shape - 1 + embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra}) else: index += -1 pad_extra += emb_shape @@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): attention_masks.append(attention_mask) num_tokens.append(sum(attention_mask)) - return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens + return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): device = self.transformer.get_input_embeddings().weight.device - embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: @@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): else: intermediate_output = self.layer_idx - outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) + outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info) if self.layer == "last": z = outputs[0].float() diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py index 551b03162..ed4638a9a 100644 --- a/comfy/text_encoders/bert.py +++ b/comfy/text_encoders/bert.py @@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module): self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) - def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) mask = None if attention_mask is not None: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 1da6a0c94..9d90d5a61 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -2,12 +2,14 @@ import torch import torch.nn as nn from dataclasses import dataclass from typing import Optional, Any +import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit import comfy.model_management +from . import qwen_vl @dataclass class Llama2Config: @@ -100,12 +102,10 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, seq_len, theta, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, device=None): theta_numerator = torch.arange(0, head_dim, 2, device=device).float() inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) - position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0) - inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) position_ids_expanded = position_ids[:, None, :].float() freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) @@ -277,7 +277,7 @@ class Llama2_(nn.Module): self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): if embeds is not None: x = embeds else: @@ -286,8 +286,11 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + if position_ids is None: + position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + freqs_cis = precompute_freqs_cis(self.config.head_dim, - x.shape[1], + position_ids, self.config.rope_theta, device=x.device) @@ -372,8 +375,38 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): self.num_layers = config.num_hidden_layers self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.dtype = dtype + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) + return self.visual(image.to(device, dtype=torch.float32), grid), grid + return None, None + + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + grid = None + for e in embeds_info: + if e.get("type") == "image": + grid = e.get("extra", None) + position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + start = e.get("index") + position_ids[:, :start] = torch.arange(0, start, device=embeds.device) + end = e.get("size") + start + len_max = int(grid.max()) // 2 + start_next = len_max + start + position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device) + position_ids[0, start:end] = start + max_d = int(grid[0][1]) // 2 + position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] + max_d = int(grid[0][2]) // 2 + position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start] + + if grid is None: + position_ids = None + + return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids) + class Gemma2_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index ce5c98097..f07318d6c 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image \\(color, shape, size, texture, objects, background\\), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): if llama_template is None: - llama_text = self.llama_template.format(text) + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) - return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + key_name = next(iter(tokens)) + embed_count = 0 + qwen_tokens = tokens[key_name] + for r in qwen_tokens: + for i in range(len(r)): + if r[i][0] == 151655: + if len(images) > embed_count: + r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return tokens class Qwen25_7BVLIModel(sd1_clip.SDClipModel): diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py new file mode 100644 index 000000000..3b18ce730 --- /dev/null +++ b/comfy/text_encoders/qwen_vl.py @@ -0,0 +1,428 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Tuple +import math +from comfy.ldm.modules.attention import optimized_attention_for_device + + +def process_qwen2vl_images( + images: torch.Tensor, + min_pixels: int = 3136, + max_pixels: int = 12845056, + patch_size: int = 14, + temporal_patch_size: int = 2, + merge_size: int = 2, + image_mean: list = None, + image_std: list = None, +): + if image_mean is None: + image_mean = [0.48145466, 0.4578275, 0.40821073] + if image_std is None: + image_std = [0.26862954, 0.26130258, 0.27577711] + + batch_size, height, width, channels = images.shape + device = images.device + # dtype = images.dtype + + images = images.permute(0, 3, 1, 2) + + grid_thw_list = [] + img = images[0] + + factor = patch_size * merge_size + + h_bar = round(height / factor) * factor + w_bar = round(width / factor) * factor + + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, math.floor(height / beta / factor) * factor) + w_bar = max(factor, math.floor(width / beta / factor) * factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = math.ceil(height * beta / factor) * factor + w_bar = math.ceil(width * beta / factor) * factor + + img_resized = F.interpolate( + img.unsqueeze(0), + size=(h_bar, w_bar), + mode='bilinear', + align_corners=False + ).squeeze(0) + + normalized = img_resized.clone() + for c in range(3): + normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c] + + grid_h = h_bar // patch_size + grid_w = w_bar // patch_size + grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long) + + pixel_values = normalized + grid_thw_list.append(grid_thw) + image_grid_thw = torch.stack(grid_thw_list) + + grid_t = 1 + channel = pixel_values.shape[0] + pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1) + + patches = pixel_values.reshape( + grid_t, + temporal_patch_size, + channel, + grid_h // merge_size, + merge_size, + patch_size, + grid_w // merge_size, + merge_size, + patch_size, + ) + + patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8) + flatten_patches = patches.reshape( + grid_t * grid_h * grid_w, + channel * temporal_patch_size * patch_size * patch_size + ) + + return flatten_patches, image_grid_thw + + +class VisionPatchEmbed(nn.Module): + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 3584, + device=None, + dtype=None, + ops=None, + ): + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = ops.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + device=device, + dtype=dtype + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.view( + -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size + ) + hidden_states = self.proj(hidden_states) + return hidden_states.view(-1, self.embed_dim) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb_vision(q, k, cos, sin): + cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float() + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class VisionRotaryEmbedding(nn.Module): + def __init__(self, dim: int, theta: float = 10000.0): + super().__init__() + self.dim = dim + self.theta = theta + + def forward(self, seqlen: int, device) -> torch.Tensor: + inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim)) + seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.outer(seq, inv_freq) + return freqs + + +class PatchMerger(nn.Module): + def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size ** 2) + self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype), + nn.GELU(), + ops.Linear(self.hidden_size, dim, device=device, dtype=dtype), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x).reshape(-1, self.hidden_size) + x = self.mlp(x) + return x + + +class VisionAttention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scaling = self.head_dim ** -0.5 + + self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype) + self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + if hidden_states.dim() == 2: + seq_length, _ = hidden_states.shape + batch_size = 1 + hidden_states = hidden_states.unsqueeze(0) + else: + batch_size, seq_length, _ = hidden_states.shape + + qkv = self.qkv(hidden_states) + qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim) + query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin) + + query_states = query_states.transpose(0, 1).unsqueeze(0) + key_states = key_states.transpose(0, 1).unsqueeze(0) + value_states = value_states.transpose(0, 1).unsqueeze(0) + + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + splits = [ + torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states) + ] + + attn_outputs = [ + optimized_attention(q, k, v, self.num_heads, skip_reshape=True) + for q, k, v in zip(*splits) + ] + attn_output = torch.cat(attn_outputs, dim=1) + attn_output = attn_output.reshape(seq_length, -1) + attn_output = self.proj(attn_output) + + return attn_output + + +class VisionMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None): + super().__init__() + self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype) + self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype) + self.act_fn = nn.SiLU() + + def forward(self, hidden_state): + return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + + +class VisionBlock(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None): + super().__init__() + self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype) + self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops) + self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + cu_seqlens=None, + optimized_attention=None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2VLVisionTransformer(nn.Module): + def __init__( + self, + hidden_size: int = 3584, + output_hidden_size: int = 3584, + intermediate_size: int = 3420, + num_heads: int = 16, + num_layers: int = 32, + patch_size: int = 14, + temporal_patch_size: int = 2, + spatial_merge_size: int = 2, + window_size: int = 112, + device=None, + dtype=None, + ops=None + ): + super().__init__() + self.hidden_size = hidden_size + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.window_size = window_size + self.fullatt_block_indexes = [7, 15, 23, 31] + + self.patch_embed = VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=3, + embed_dim=hidden_size, + device=device, + dtype=dtype, + ops=ops, + ) + + head_dim = hidden_size // num_heads + self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops) + for _ in range(num_layers) + ]) + + self.merger = PatchMerger( + dim=output_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + device=device, + dtype=dtype, + ops=ops, + ) + + def get_window_index(self, grid_thw): + window_index = [] + cu_window_seqlens = [0] + window_index_id = 0 + vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w) + + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + + index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100) + index_padded = index_padded.reshape( + grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size, + ) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size, + ) + + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + + cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def get_position_embeddings(self, grid_thw, device): + pos_ids = [] + + for t, h, w in grid_thw: + hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten() + + wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1) + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ) + wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten() + + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device) + return rotary_pos_emb_full[pos_ids].flatten(1) + + def forward( + self, + pixel_values: torch.Tensor, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True) + + hidden_states = self.patch_embed(pixel_values) + + window_index, cu_window_seqlens = self.get_window_index(image_grid_thw) + cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + + position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device) + + seq_len, _ = hidden_states.size() + spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + + hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + + position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1) + position_embeddings = position_embeddings[window_index, :, :] + position_embeddings = position_embeddings.reshape(seq_len, -1) + position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1) + position_embeddings = (position_embeddings.cos(), position_embeddings.sin()) + + cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum( + dim=0, + dtype=torch.int32, + ) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + for i, block in enumerate(self.blocks): + if i in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) + + hidden_states = self.merger(hidden_states) + return hidden_states diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py index 36bf35309..e8588992a 100644 --- a/comfy/text_encoders/t5.py +++ b/comfy/text_encoders/t5.py @@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module): self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) # self.dropout = nn.Dropout(config.dropout_rate) - def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): + def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py new file mode 100644 index 000000000..b5088fae2 --- /dev/null +++ b/comfy_extras/nodes_qwen.py @@ -0,0 +1,63 @@ +import node_helpers +import comfy.utils + +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + + +class TextEncodeQwenImageEdit: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }, + "optional": {"vae": ("VAE", ), + "image": ("IMAGE", ),}} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, prompt, vae=None, image=None): + ref_latent = None + if image is None: + images = [] + else: + images = [image] + if vae is not None: + width = image.shape[2] + height = image.shape[1] + aspect_ratio = width / height + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS) + image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) + ref_latent = vae.encode(image[:, :, :, :3]) + + tokens = clip.tokenize(prompt, images=images) + conditioning = clip.encode_from_tokens_scheduled(tokens) + if ref_latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True) + return (conditioning, ) + + +NODE_CLASS_MAPPINGS = { + "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit, +} diff --git a/nodes.py b/nodes.py index 860a236aa..b3fa9c51a 100644 --- a/nodes.py +++ b/nodes.py @@ -2321,6 +2321,7 @@ async def init_builtin_extra_nodes(): "nodes_edit_model.py", "nodes_tcfg.py", "nodes_context_windows.py", + "nodes_qwen.py", ] import_failed = [] From 36b5127fd3eee8eaf95ff7296a61269ed56d53c0 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:28:07 +0300 Subject: [PATCH 46/71] api_nodes: add kling-v2-1 and v2-1-master (#9257) --- comfy_api_nodes/apis/__init__.py | 3 +++ comfy_api_nodes/nodes_kling.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 54298e8a9..c6f91e9d6 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1315,6 +1315,7 @@ class KlingTaskStatus(str, Enum): class KlingTextToVideoModelName(str, Enum): kling_v1 = 'kling-v1' kling_v1_6 = 'kling-v1-6' + kling_v2_1_master = 'kling-v2-1-master' class KlingVideoGenAspectRatio(str, Enum): @@ -1347,6 +1348,8 @@ class KlingVideoGenModelName(str, Enum): kling_v1_5 = 'kling-v1-5' kling_v1_6 = 'kling-v1-6' kling_v2_master = 'kling-v2-master' + kling_v2_1 = 'kling-v2-1' + kling_v2_1_master = 'kling-v2-1-master' class KlingVideoResult(BaseModel): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9d483bb0e..9fa390985 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -421,6 +421,8 @@ class KlingTextToVideoNode(KlingNodeBase): "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"), "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"), "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), + "pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"), + "pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"), } @classmethod From f16a70ba670e11de549af188663a87c77c5bc0c2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:28:27 +0300 Subject: [PATCH 47/71] api_nodes: add MinimaxHailuoVideoNode node (#9262) --- comfy_api_nodes/apis/__init__.py | 13 ++- comfy_api_nodes/nodes_minimax.py | 185 ++++++++++++++++++++++++++++++- 2 files changed, 191 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index c6f91e9d6..7a09df55b 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1623,13 +1623,14 @@ class MinimaxTaskResultResponse(BaseModel): task_id: str = Field(..., description='The task ID being queried.') -class Model(str, Enum): +class MiniMaxModel(str, Enum): T2V_01_Director = 'T2V-01-Director' I2V_01_Director = 'I2V-01-Director' S2V_01 = 'S2V-01' I2V_01 = 'I2V-01' I2V_01_live = 'I2V-01-live' T2V_01 = 'T2V-01' + Hailuo_02 = 'MiniMax-Hailuo-02' class SubjectReferenceItem(BaseModel): @@ -1651,7 +1652,7 @@ class MinimaxVideoGenerationRequest(BaseModel): None, description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', ) - model: Model = Field( + model: MiniMaxModel = Field( ..., description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', ) @@ -1668,6 +1669,14 @@ class MinimaxVideoGenerationRequest(BaseModel): None, description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', ) + duration: Optional[int] = Field( + None, + description="The length of the output video in seconds." + ) + resolution: Optional[str] = Field( + None, + description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels." + ) class MinimaxVideoGenerationResponse(BaseModel): diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 58d2ed90c..bb3c9e710 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -1,3 +1,4 @@ +from inspect import cleandoc from typing import Union import logging import torch @@ -10,7 +11,7 @@ from comfy_api_nodes.apis import ( MinimaxFileRetrieveResponse, MinimaxTaskResultResponse, SubjectReferenceItem, - Model + MiniMaxModel ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -84,7 +85,6 @@ class MinimaxTextToVideoNode: FUNCTION = "generate_video" CATEGORY = "api node/video/MiniMax" API_NODE = True - OUTPUT_NODE = True async def generate_video( self, @@ -121,7 +121,7 @@ class MinimaxTextToVideoNode: response_model=MinimaxVideoGenerationResponse, ), request=MinimaxVideoGenerationRequest( - model=Model(model), + model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, first_frame_image=image_url, @@ -251,7 +251,6 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode): FUNCTION = "generate_video" CATEGORY = "api node/video/MiniMax" API_NODE = True - OUTPUT_NODE = True class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): @@ -313,7 +312,181 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): FUNCTION = "generate_video" CATEGORY = "api node/video/MiniMax" API_NODE = True - OUTPUT_NODE = True + + +class MinimaxHailuoVideoNode: + """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "prompt_text": ( + "STRING", + { + "multiline": True, + "default": "", + "tooltip": "Text prompt to guide the video generation.", + }, + ), + }, + "optional": { + "seed": ( + IO.INT, + { + "default": 0, + "min": 0, + "max": 0xFFFFFFFFFFFFFFFF, + "control_after_generate": True, + "tooltip": "The random seed used for creating the noise.", + }, + ), + "first_frame_image": ( + IO.IMAGE, + { + "tooltip": "Optional image to use as the first frame to generate a video." + }, + ), + "prompt_optimizer": ( + IO.BOOLEAN, + { + "tooltip": "Optimize prompt to improve generation quality when needed.", + "default": True, + }, + ), + "duration": ( + IO.COMBO, + { + "tooltip": "The length of the output video in seconds.", + "default": 6, + "options": [6, 10], + }, + ), + "resolution": ( + IO.COMBO, + { + "tooltip": "The dimensions of the video display. " + "1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.", + "default": "768P", + "options": ["768P", "1080P"], + }, + ), + }, + "hidden": { + "auth_token": "AUTH_TOKEN_COMFY_ORG", + "comfy_api_key": "API_KEY_COMFY_ORG", + "unique_id": "UNIQUE_ID", + }, + } + + RETURN_TYPES = ("VIDEO",) + DESCRIPTION = cleandoc(__doc__ or "") + FUNCTION = "generate_video" + CATEGORY = "api node/video/MiniMax" + API_NODE = True + + async def generate_video( + self, + prompt_text, + seed=0, + first_frame_image: torch.Tensor=None, # used for ImageToVideo + prompt_optimizer=True, + duration=6, + resolution="768P", + model="MiniMax-Hailuo-02", + unique_id: Union[str, None]=None, + **kwargs, + ): + if first_frame_image is None: + validate_string(prompt_text, field_name="prompt_text") + + if model == "MiniMax-Hailuo-02" and resolution.upper() == "1080P" and duration != 6: + raise Exception( + "When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds." + ) + + # upload image, if passed in + image_url = None + if first_frame_image is not None: + image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0] + + video_generate_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/video_generation", + method=HttpMethod.POST, + request_model=MinimaxVideoGenerationRequest, + response_model=MinimaxVideoGenerationResponse, + ), + request=MinimaxVideoGenerationRequest( + model=MiniMaxModel(model), + prompt=prompt_text, + callback_url=None, + first_frame_image=image_url, + prompt_optimizer=prompt_optimizer, + duration=duration, + resolution=resolution, + ), + auth_kwargs=kwargs, + ) + response = await video_generate_operation.execute() + + task_id = response.task_id + if not task_id: + raise Exception(f"MiniMax generation failed: {response.base_resp}") + + average_duration = 120 if resolution == "768P" else 240 + video_generate_operation = PollingOperation( + poll_endpoint=ApiEndpoint( + path="/proxy/minimax/query/video_generation", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxTaskResultResponse, + query_params={"task_id": task_id}, + ), + completed_statuses=["Success"], + failed_statuses=["Fail"], + status_extractor=lambda x: x.status.value, + estimated_duration=average_duration, + node_id=unique_id, + auth_kwargs=kwargs, + ) + task_result = await video_generate_operation.execute() + + file_id = task_result.file_id + if file_id is None: + raise Exception("Request was not successful. Missing file ID.") + file_retrieve_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/minimax/files/retrieve", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=MinimaxFileRetrieveResponse, + query_params={"file_id": int(file_id)}, + ), + request=EmptyRequest(), + auth_kwargs=kwargs, + ) + file_result = await file_retrieve_operation.execute() + + file_url = file_result.file.download_url + if file_url is None: + raise Exception( + f"No video was found in the response. Full response: {file_result.model_dump()}" + ) + logging.info(f"Generated video URL: {file_url}") + if unique_id: + if hasattr(file_result.file, "backup_download_url"): + message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" + else: + message = f"Result URL: {file_url}" + PromptServer.instance.send_progress_text(message, unique_id) + + video_io = await download_url_to_bytesio(file_url) + if video_io is None: + error_msg = f"Failed to download video from {file_url}" + logging.error(error_msg) + raise Exception(error_msg) + return (VideoFromFile(video_io),) # A dictionary that contains all nodes you want to export with their names @@ -322,6 +495,7 @@ NODE_CLASS_MAPPINGS = { "MinimaxTextToVideoNode": MinimaxTextToVideoNode, "MinimaxImageToVideoNode": MinimaxImageToVideoNode, # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode, + "MinimaxHailuoVideoNode": MinimaxHailuoVideoNode, } # A dictionary that contains the friendly/humanly readable titles for the nodes @@ -329,4 +503,5 @@ NODE_DISPLAY_NAME_MAPPINGS = { "MinimaxTextToVideoNode": "MiniMax Text to Video", "MinimaxImageToVideoNode": "MiniMax Image to Video", "MinimaxSubjectToVideoNode": "MiniMax Subject to Video", + "MinimaxHailuoVideoNode": "MiniMax Hailuo Video", } From 07a927517cfaf099fec3903e8973f758e62d65f9 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:29:01 +0300 Subject: [PATCH 48/71] api_nodes: add GPT-5 series models (#9325) --- comfy_api_nodes/nodes_openai.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index cbff2b2d2..674c9ede0 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -80,6 +80,9 @@ class SupportedOpenAIModel(str, Enum): gpt_4_1 = "gpt-4.1" gpt_4_1_mini = "gpt-4.1-mini" gpt_4_1_nano = "gpt-4.1-nano" + gpt_5 = "gpt-5" + gpt_5_mini = "gpt-5-mini" + gpt_5_nano = "gpt-5-nano" class OpenAIDalle2(ComfyNodeABC): From d844d8b13bfd6a83b0a7d0491aa2978ac44a6158 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:29:24 +0300 Subject: [PATCH 49/71] api_nodes: added release version of google's models (#9304) --- comfy_api_nodes/nodes_gemini.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3751fb2a1..ba4167a50 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -46,6 +46,8 @@ class GeminiModel(str, Enum): gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" + gemini_2_5_pro = "gemini-2.5-pro" + gemini_2_5_flash = "gemini-2.5-flash" def get_gemini_endpoint( @@ -97,7 +99,7 @@ class GeminiNode(ComfyNodeABC): { "tooltip": "The Gemini model to use for generating responses.", "options": [model.value for model in GeminiModel], - "default": GeminiModel.gemini_2_5_pro_preview_05_06.value, + "default": GeminiModel.gemini_2_5_pro.value, }, ), "seed": ( From 54d8fdbed0a7b171ab8cfb02e29a7e0dc5fe78fd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:30:06 +0300 Subject: [PATCH 50/71] feat(api-nodes): add Vidu Video nodes (#9368) --- comfy_api_nodes/nodes_vidu.py | 622 +++++++++++++++++++++++ comfy_api_nodes/util/validation_utils.py | 53 ++ nodes.py | 1 + 3 files changed, 676 insertions(+) create mode 100644 comfy_api_nodes/nodes_vidu.py diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py new file mode 100644 index 000000000..2f441948c --- /dev/null +++ b/comfy_api_nodes/nodes_vidu.py @@ -0,0 +1,622 @@ +import logging +from enum import Enum +from typing import Any, Callable, Optional, Literal, TypeVar +from typing_extensions import override + +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api_nodes.util.validation_utils import ( + validate_aspect_ratio_closeness, + validate_image_dimensions, + validate_image_aspect_ratio_range, + get_number_of_images, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + HttpMethod, + SynchronousOperation, + PollingOperation, + EmptyRequest, +) +from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi + + +VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" +VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" +VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" +VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" +VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" + +R = TypeVar("R") + +class VideoModelName(str, Enum): + vidu_q1 = 'viduq1' + + +class AspectRatio(str, Enum): + r_16_9 = "16:9" + r_9_16 = "9:16" + r_1_1 = "1:1" + + +class Resolution(str, Enum): + r_1080p = "1080p" + + +class MovementAmplitude(str, Enum): + auto = "auto" + small = "small" + medium = "medium" + large = "large" + + +class TaskCreationRequest(BaseModel): + model: VideoModelName = VideoModelName.vidu_q1 + prompt: Optional[str] = Field(None, max_length=1500) + duration: Optional[Literal[5]] = 5 + seed: Optional[int] = Field(0, ge=0, le=2147483647) + aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9 + resolution: Optional[Resolution] = Resolution.r_1080p + movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto + images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") + + +class TaskStatus(str, Enum): + created = "created" + queueing = "queueing" + processing = "processing" + success = "success" + failed = "failed" + + +class TaskCreationResponse(BaseModel): + task_id: str = Field(...) + state: TaskStatus = Field(...) + created_at: str = Field(...) + code: Optional[int] = Field(None, description="Error code") + + +class TaskResult(BaseModel): + id: str = Field(..., description="Creation id") + url: str = Field(..., description="The URL of the generated results, valid for one hour") + cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") + + +class TaskStatusResponse(BaseModel): + state: TaskStatus = Field(...) + err_code: Optional[str] = Field(None) + creations: list[TaskResult] = Field(..., description="Generated results") + + +async def poll_until_finished( + auth_kwargs: dict[str, str], + api_endpoint: ApiEndpoint[Any, R], + result_url_extractor: Optional[Callable[[R], str]] = None, + estimated_duration: Optional[int] = None, + node_id: Optional[str] = None, +) -> R: + return await PollingOperation( + poll_endpoint=api_endpoint, + completed_statuses=[TaskStatus.success.value], + failed_statuses=[TaskStatus.failed.value], + status_extractor=lambda response: response.state.value, + auth_kwargs=auth_kwargs, + result_url_extractor=result_url_extractor, + estimated_duration=estimated_duration, + node_id=node_id, + poll_interval=16.0, + max_poll_attempts=256, + ).execute() + + +def get_video_url_from_response(response) -> Optional[str]: + if response.creations: + return response.creations[0].url + return None + + +def get_video_from_response(response) -> TaskResult: + if not response.creations: + error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" + logging.info(error_msg) + raise RuntimeError(error_msg) + logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url) + return response.creations[0] + + +async def execute_task( + vidu_endpoint: str, + auth_kwargs: Optional[dict[str, str]], + payload: TaskCreationRequest, + estimated_duration: int, + node_id: str, +) -> R: + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=vidu_endpoint, + method=HttpMethod.POST, + request_model=TaskCreationRequest, + response_model=TaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + if response.state == TaskStatus.failed: + error_msg = f"Vidu request failed. Code: {response.code}" + logging.error(error_msg) + raise RuntimeError(error_msg) + return await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=VIDU_GET_GENERATION_STATUS % response.task_id, + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + result_url_extractor=get_video_url_from_response, + estimated_duration=estimated_duration, + node_id=node_id, + ) + + +class ViduTextToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduTextToVideoNode", + display_name="Vidu Text To Video Generation", + category="api node/video/Vidu", + description="Generate video from text prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[model.value for model in AspectRatio], + default=AspectRatio.r_16_9.value, + tooltip="The aspect ratio of the output video", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + if not prompt: + raise ValueError("The prompt field is required and cannot be empty.") + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduImageToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduImageToVideoNode", + display_name="Vidu Image To Video Generation", + category="api node/video/Vidu", + description="Generate video from image and optional prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="A textual description for video generation", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + if get_number_of_images(image) > 1: + raise ValueError("Only one input image is allowed.") + validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + payload.images = await upload_images_to_comfyapi( + image, + max_images=1, + mime_type="image/png", + auth_kwargs=auth, + ) + results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduReferenceVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduReferenceVideoNode", + display_name="Vidu Reference To Video Generation", + category="api node/video/Vidu", + description="Generate video from multiple images and prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "images", + tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[model.value for model in AspectRatio], + default=AspectRatio.r_16_9.value, + tooltip="The aspect ratio of the output video", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + images: torch.Tensor, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + if not prompt: + raise ValueError("The prompt field is required and cannot be empty.") + a = get_number_of_images(images) + if a > 7: + raise ValueError("Too many images, maximum allowed is 7.") + for image in images: + validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_dimensions(image, min_width=128, min_height=128) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + payload.images = await upload_images_to_comfyapi( + images, + max_images=7, + mime_type="image/png", + auth_kwargs=auth, + ) + results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduStartEndToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ViduStartEndToVideoNode", + display_name="Vidu Start End To Video Generation", + category="api node/video/Vidu", + description="Generate a video from start and end frames and a prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in VideoModelName], + default=VideoModelName.vidu_q1.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "first_frame", + tooltip="Start frame", + ), + comfy_io.Image.Input( + "end_frame", + tooltip="End frame", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation", + optional=True, + ), + comfy_io.Int.Input( + "duration", + default=5, + min=5, + max=5, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Combo.Input( + "resolution", + options=[model.value for model in Resolution], + default=Resolution.r_1080p.value, + tooltip="Supported values may vary by model & duration", + optional=True, + ), + comfy_io.Combo.Input( + "movement_amplitude", + options=[model.value for model in MovementAmplitude], + default=MovementAmplitude.auto.value, + tooltip="The movement amplitude of objects in the frame", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + first_frame: torch.Tensor, + end_frame: torch.Tensor, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> comfy_io.NodeOutput: + validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model_name=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + ) + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + payload.images = [ + (await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0] + for frame in (first_frame, end_frame) + ] + results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + + +class ViduExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + ViduTextToVideoNode, + ViduImageToVideoNode, + ViduReferenceVideoNode, + ViduStartEndToVideoNode, + ] + +async def comfy_entrypoint() -> ViduExtension: + return ViduExtension() diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 031b9fbd3..606b794bf 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -53,6 +53,53 @@ def validate_image_aspect_ratio( ) +def validate_image_aspect_ratio_range( + image: torch.Tensor, + min_ratio: tuple[float, float], # e.g. (1, 4) + max_ratio: tuple[float, float], # e.g. (4, 1) + *, + strict: bool = True, # True -> (min, max); False -> [min, max] +) -> float: + a1, b1 = min_ratio + a2, b2 = max_ratio + if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0: + raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).") + lo, hi = (a1 / b1), (a2 / b2) + if lo > hi: + lo, hi = hi, lo + a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text + w, h = get_image_dimensions(image) + if w <= 0 or h <= 0: + raise ValueError(f"Invalid image dimensions: {w}x{h}") + ar = w / h + ok = (lo < ar < hi) if strict else (lo <= ar <= hi) + if not ok: + op = "<" if strict else "≤" + raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}") + return ar + + +def validate_aspect_ratio_closeness( + start_img, + end_img, + min_rel: float, + max_rel: float, + *, + strict: bool = False, # True => exclusive, False => inclusive +) -> None: + w1, h1 = get_image_dimensions(start_img) + w2, h2 = get_image_dimensions(end_img) + if min(w1, h1, w2, h2) <= 0: + raise ValueError("Invalid image dimensions") + ar1 = w1 / h1 + ar2 = w2 / h2 + # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1) + closeness = max(ar1, ar2) / min(ar1, ar2) + limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25 + if (closeness >= limit) if strict else (closeness > limit): + raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.") + + def validate_video_dimensions( video: VideoInput, min_width: Optional[int] = None, @@ -98,3 +145,9 @@ def validate_video_duration( raise ValueError( f"Video duration must be at most {max_duration}s, got {duration}s" ) + + +def get_number_of_images(images): + if isinstance(images, torch.Tensor): + return images.shape[0] if images.ndim >= 4 else 1 + return len(images) diff --git a/nodes.py b/nodes.py index b3fa9c51a..35dda1b19 100644 --- a/nodes.py +++ b/nodes.py @@ -2351,6 +2351,7 @@ async def init_builtin_api_nodes(): "nodes_moonvalley.py", "nodes_rodin.py", "nodes_gemini.py", + "nodes_vidu.py", ] if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): From bddd69618bf4463209c3681babfcbebd9b9aed85 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 13:49:01 -0700 Subject: [PATCH 51/71] Change the TextEncodeQwenImageEdit node to use logic closer to reference. (#9432) --- comfy_extras/nodes_qwen.py | 37 +++++++++++-------------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index b5088fae2..fff89556f 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -1,25 +1,6 @@ import node_helpers import comfy.utils - -PREFERRED_QWENIMAGE_RESOLUTIONS = [ - (672, 1568), - (688, 1504), - (720, 1456), - (752, 1392), - (800, 1328), - (832, 1248), - (880, 1184), - (944, 1104), - (1024, 1024), - (1104, 944), - (1184, 880), - (1248, 832), - (1328, 800), - (1392, 752), - (1456, 720), - (1504, 688), - (1568, 672), -] +import math class TextEncodeQwenImageEdit: @@ -42,13 +23,17 @@ class TextEncodeQwenImageEdit: if image is None: images = [] else: - images = [image] + samples = image.movedim(-1, 1) + total = int(1024 * 1024) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled") + image = s.movedim(1, -1) + images = [image[:, :, :, :3]] if vae is not None: - width = image.shape[2] - height = image.shape[1] - aspect_ratio = width / height - _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS) - image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) ref_latent = vae.encode(image[:, :, :, :3]) tokens = clip.tokenize(prompt, images=images) From dfa791eb4bfcaac3eb9b2b33fa15ae5a25589bb8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 17:47:42 -0700 Subject: [PATCH 52/71] Rope fix for qwen vl. (#9435) --- comfy/text_encoders/llama.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 9d90d5a61..4c976058f 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -27,6 +27,7 @@ class Llama2Config: rms_norm_add = False mlp_activation = "silu" qkv_bias = False + rope_dims = None @dataclass class Qwen25_3BConfig: @@ -44,6 +45,7 @@ class Qwen25_3BConfig: rms_norm_add = False mlp_activation = "silu" qkv_bias = True + rope_dims = None @dataclass class Qwen25_7BVLI_Config: @@ -61,6 +63,7 @@ class Qwen25_7BVLI_Config: rms_norm_add = False mlp_activation = "silu" qkv_bias = True + rope_dims = [16, 24, 24] @dataclass class Gemma2_2B_Config: @@ -78,6 +81,7 @@ class Gemma2_2B_Config: rms_norm_add = True mlp_activation = "gelu_pytorch_tanh" qkv_bias = False + rope_dims = None class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -102,7 +106,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def precompute_freqs_cis(head_dim, position_ids, theta, device=None): +def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None): theta_numerator = torch.arange(0, head_dim, 2, device=device).float() inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) @@ -112,12 +116,20 @@ def precompute_freqs_cis(head_dim, position_ids, theta, device=None): emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() sin = emb.sin() + if rope_dims is not None and position_ids.shape[0] > 1: + mrope_section = rope_dims * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0) + else: + cos = cos.unsqueeze(1) + sin = sin.unsqueeze(1) + return (cos, sin) def apply_rope(xq, xk, freqs_cis): - cos = freqs_cis[0].unsqueeze(1) - sin = freqs_cis[1].unsqueeze(1) + cos = freqs_cis[0] + sin = freqs_cis[1] q_embed = (xq * cos) + (rotate_half(xq) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin) return q_embed, k_embed @@ -292,6 +304,7 @@ class Llama2_(nn.Module): freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, + self.config.rope_dims, device=x.device) mask = None From 7cd2c4bd6ab20f35a6bb1b1f2252c3ea16da4777 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 21:45:27 -0700 Subject: [PATCH 53/71] Qwen rotary embeddings should now match reference code. (#9437) --- comfy/ldm/qwen_image/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index a3c726299..bf3940313 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -349,8 +349,8 @@ class QwenImageTransformer2DModel(nn.Module): img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape def forward( @@ -396,7 +396,7 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size), ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size))) + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) From 5a8f502db5889873ffa13132b603b7b6daac605a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 22:08:11 -0700 Subject: [PATCH 54/71] Disable prompt weights for qwen. (#9438) --- comfy/sd1_clip.py | 5 ++++- comfy/text_encoders/qwen_image.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 1e8adbe69..f8a7c2a1b 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -534,7 +534,10 @@ class SDTokenizer: min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) text = escape_important(text) - parsed_weights = token_weights(text, 1.0) + if kwargs.get("disable_weights", False): + parsed_weights = [(text, 1.0)] + else: + parsed_weights = token_weights(text, 1.0) # tokenize words tokens = [] diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index f07318d6c..6646b1003 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -15,7 +15,7 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" - self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image \\(color, shape, size, texture, objects, background\\), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): if llama_template is None: @@ -25,7 +25,7 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) - tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) key_name = next(iter(tokens)) embed_count = 0 qwen_tokens = tokens[key_name] From 8d38ea3bbf7e77ed7e7aee401b157dab211c5307 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 Aug 2025 23:58:54 -0700 Subject: [PATCH 55/71] Fix bf16 precision issue with qwen image embeddings. (#9441) --- comfy/ldm/qwen_image/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index bf3940313..49f66b90a 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -347,7 +347,7 @@ class QwenImageTransformer2DModel(nn.Module): h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device) img_ids[:, :, 0] = img_ids[:, :, 1] + index img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) @@ -397,9 +397,10 @@ class QwenImageTransformer2DModel(nn.Module): img_ids = torch.cat([img_ids, kontext_ids], dim=1) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) - txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) From 2f52e8f05f2039dd67e9b9783b8397350a548b95 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Wed, 20 Aug 2025 15:15:09 +0800 Subject: [PATCH 56/71] Bump template to 0.1.62 (#9419) * Bump template to 0.1.61 * Bump template to 0.1.62 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c7a5c47ab..8d928d826 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.9 -comfyui-workflow-templates==0.1.60 +comfyui-workflow-templates==0.1.62 comfyui-embedded-docs==0.2.6 torch torchsde From 7139d6d93fc7b5481a69b687080bd36f7b531c46 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 Aug 2025 03:15:30 -0400 Subject: [PATCH 57/71] ComfyUI version 0.3.51 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 29ec07ca6..65f06cf37 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.50" +__version__ = "0.3.51" diff --git a/pyproject.toml b/pyproject.toml index 659b5730a..ecbf04303 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.50" +version = "0.3.51" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" From fe01885acf892de636b1b2743903812099bd42e3 Mon Sep 17 00:00:00 2001 From: Harel Cain Date: Wed, 20 Aug 2025 09:33:10 +0200 Subject: [PATCH 58/71] LTXV: fix key frame noise mask dimensions for when real noise mask exists (#9425) --- comfy_extras/nodes_lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index b5058667a..f82337a67 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -166,7 +166,7 @@ class LTXVAddGuide: negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1), + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), 1.0 - strength, dtype=noise_mask.dtype, device=noise_mask.device, From e73a9dbe30434280c69d852ea78cc4bf88bfd501 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 14:34:13 -0700 Subject: [PATCH 59/71] Add that qwen edit model is supported to readme. (#9463) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index fa99a8cbe..79a8a8c79 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) - [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11) + - [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model) - Video Models - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) From 0963493a9c3b6565f8537288a0fb90991391ec41 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:26:37 -0700 Subject: [PATCH 60/71] Support for Qwen Diffsynth Controlnets canny and depth. (#9465) These are not real controlnets but actually a patch on the model so they will be treated as such. Put them in the models/model_patches/ folder. Use the new ModelPatchLoader and QwenImageDiffsynthControlnet nodes. --- comfy/ldm/qwen_image/model.py | 7 + comfy/model_management.py | 8 +- comfy/model_patcher.py | 27 ++++ comfy_api/latest/_io.py | 4 + comfy_extras/nodes_model_patch.py | 138 ++++++++++++++++++++ models/model_patches/put_model_patches_here | 0 nodes.py | 1 + 7 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_model_patch.py create mode 100644 models/model_patches/put_model_patches_here diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 49f66b90a..2503583cb 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -416,6 +416,7 @@ class QwenImageTransformer2DModel(nn.Module): ) patches_replace = transformer_options.get("patches_replace", {}) + patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) for i, block in enumerate(self.transformer_blocks): @@ -436,6 +437,12 @@ class QwenImageTransformer2DModel(nn.Module): image_rotary_emb=image_rotary_emb, ) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2a9f18068..d08aee1fe 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu else: minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) - models = set(models) + models_temp = set() + for m in models: + models_temp.add(m) + for mm in m.model_patches_models(): + models_temp.add(mm) + + models = models_temp models_to_load = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 52e76b5f3..a944cb421 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -430,6 +430,9 @@ class ModelPatcher: def set_model_forward_timestep_embed_patch(self, patch): self.set_model_patch(patch, "forward_timestep_embed_patch") + def set_model_double_block_patch(self, patch): + self.set_model_patch(patch, "double_block") + def add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -486,6 +489,30 @@ class ModelPatcher: if hasattr(wrap_func, "to"): self.model_options["model_function_wrapper"] = wrap_func.to(device) + def model_patches_models(self): + to = self.model_options["transformer_options"] + models = [] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "models"): + models += patch_list[i].models() + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], "models"): + models += patch_list[k].models() + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, "models"): + models += wrap_func.models() + + return models + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index ec1efb51d..a3a21facc 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO): class AnyType(ComfyTypeIO): Type = Any +@comfytype(io_type="MODEL_PATCH") +class MODEL_PATCH(ComfyTypeIO): + Type = Any + @comfytype(io_type="COMFY_MULTITYPED_V3") class MultiType: Type = Any diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py new file mode 100644 index 000000000..bb239bc45 --- /dev/null +++ b/comfy_extras/nodes_model_patch.py @@ -0,0 +1,138 @@ +import torch +import folder_paths +import comfy.utils +import comfy.ops +import comfy.model_management +import comfy.ldm.common_dit +import comfy.latent_formats + + +class BlockWiseControlBlock(torch.nn.Module): + # [linear, gelu, linear] + def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None): + super().__init__() + self.x_rms = operations.RMSNorm(dim, eps=1e-6) + self.y_rms = operations.RMSNorm(dim, eps=1e-6) + self.input_proj = operations.Linear(dim, dim) + self.act = torch.nn.GELU() + self.output_proj = operations.Linear(dim, dim) + + def forward(self, x, y): + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + +class QwenImageBlockWiseControlNet(torch.nn.Module): + def __init__( + self, + num_layers: int = 60, + in_dim: int = 64, + additional_in_dim: int = 0, + dim: int = 3072, + device=None, dtype=None, operations=None + ): + super().__init__() + self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype) + self.controlnet_blocks = torch.nn.ModuleList( + [ + BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + def process_input_latent_image(self, latent_image): + latent_image = comfy.latent_formats.Wan21().process_in(latent_image) + patch_size = 2 + hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) + hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + return self.img_in(hidden_states) + + def control_block(self, img, controlnet_conditioning, block_id): + return self.controlnet_blocks[block_id](img, controlnet_conditioning) + + +class ModelPatchLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ), + }} + RETURN_TYPES = ("MODEL_PATCH",) + FUNCTION = "load_model_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders" + + def load_model_patch(self, name): + model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) + sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) + dtype = comfy.utils.weight_dtype(sd) + # TODO: this node will work with more types of model patches + model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + model.load_state_dict(sd) + model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + return (model,) + + +class DiffSynthCnetPatch: + def __init__(self, model_patch, vae, image, strength): + self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image)) + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + block_index = kwargs.get("block_index") + if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]: + spacial_compression = self.vae.spacial_compression_encode() + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1))) + comfy.model_management.load_models_gpu(loaded_models) + + img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength) + kwargs['img'] = img + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + +class QwenImageDiffsynthControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "vae": ("VAE",), + "image": ("IMAGE",), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "diffsynth_controlnet" + EXPERIMENTAL = True + + CATEGORY = "advanced/loaders/qwen" + + def diffsynth_controlnet(self, model, model_patch, vae, image, strength): + model_patched = model.clone() + image = image[:, :, :, :3] + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength)) + return (model_patched,) + + +NODE_CLASS_MAPPINGS = { + "ModelPatchLoader": ModelPatchLoader, + "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, +} diff --git a/models/model_patches/put_model_patches_here b/models/model_patches/put_model_patches_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 35dda1b19..9681750d3 100644 --- a/nodes.py +++ b/nodes.py @@ -2322,6 +2322,7 @@ async def init_builtin_extra_nodes(): "nodes_tcfg.py", "nodes_context_windows.py", "nodes_qwen.py", + "nodes_model_patch.py" ] import_failed = [] From 0737b7e0d245de20192064da4888debbef3241c2 Mon Sep 17 00:00:00 2001 From: saurabh-pingale Date: Thu, 21 Aug 2025 07:57:57 +0530 Subject: [PATCH 61/71] fix(userdata): catch invalid workflow filenames (#9434) (#9445) --- app/user_manager.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/app/user_manager.py b/app/user_manager.py index 0ec3e46ea..a2d376c0c 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -363,10 +363,17 @@ class UserManager(): if not overwrite and os.path.exists(path): return web.Response(status=409, text="File already exists") - body = await request.read() + try: + body = await request.read() - with open(path, "wb") as f: - f.write(body) + with open(path, "wb") as f: + f.write(body) + except OSError as e: + logging.warning(f"Error saving file '{path}': {e}") + return web.Response( + status=400, + reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|" + ) user_path = self.get_request_user_filepath(request, None) if full_info: From 9fa1036f60b5264302072453be524aa55928bbaf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 20:09:35 -0700 Subject: [PATCH 62/71] Forgot this. (#9470) --- folder_paths.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/folder_paths.py b/folder_paths.py index 9ec952940..b34af39e8 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -46,6 +46,8 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")] folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) +folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") From 1b2de2642d38099acdde7c460d133d93e91074f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 Aug 2025 21:33:49 -0700 Subject: [PATCH 63/71] Support diffsynth inpaint controlnet (model patch). (#9471) --- comfy_extras/nodes_model_patch.py | 39 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index bb239bc45..3eaada9bc 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -35,6 +35,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): device=None, dtype=None, operations=None ): super().__init__() + self.additional_in_dim = additional_in_dim self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype) self.controlnet_blocks = torch.nn.ModuleList( [ @@ -44,7 +45,7 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): ) def process_input_latent_image(self, latent_image): - latent_image = comfy.latent_formats.Wan21().process_in(latent_image) + latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16]) patch_size = 2 hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size)) orig_shape = hidden_states.shape @@ -73,19 +74,33 @@ class ModelPatchLoader: sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) dtype = comfy.utils.weight_dtype(sd) # TODO: this node will work with more types of model patches - model = QwenImageBlockWiseControlNet(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + additional_in_dim = sd["img_in.weight"].shape[1] - 64 + model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return (model,) class DiffSynthCnetPatch: - def __init__(self, model_patch, vae, image, strength): - self.encoded_image = model_patch.model.process_input_latent_image(vae.encode(image)) + def __init__(self, model_patch, vae, image, strength, mask=None): self.model_patch = model_patch self.vae = vae self.image = image self.strength = strength + self.mask = mask + self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image)) + + def encode_latent_cond(self, image): + latent_image = self.vae.encode(image) + if self.model_patch.model.additional_in_dim > 0: + if self.mask is None: + mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4] + else: + mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + + return torch.cat([latent_image, mask_], dim=1) + else: + return latent_image def __call__(self, kwargs): x = kwargs.get("x") @@ -95,7 +110,7 @@ class DiffSynthCnetPatch: spacial_compression = self.vae.spacial_compression_encode() image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.encoded_image = self.model_patch.model.process_input_latent_image(self.vae.encode(image_scaled.movedim(1, -1))) + self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1))) comfy.model_management.load_models_gpu(loaded_models) img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength) @@ -118,17 +133,25 @@ class QwenImageDiffsynthControlnet: "vae": ("VAE",), "image": ("IMAGE",), "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} + }, + "optional": {"mask": ("MASK",)}} RETURN_TYPES = ("MODEL",) FUNCTION = "diffsynth_controlnet" EXPERIMENTAL = True CATEGORY = "advanced/loaders/qwen" - def diffsynth_controlnet(self, model, model_patch, vae, image, strength): + def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None): model_patched = model.clone() image = image[:, :, :, :3] - model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength)) + if mask is not None: + if mask.ndim == 3: + mask = mask.unsqueeze(1) + if mask.ndim == 4: + mask = mask.unsqueeze(2) + mask = 1.0 - mask + + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) From bc49106837b627eb657fc86f2e475770ac5ce68a Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:03:57 +0300 Subject: [PATCH 64/71] convert String nodes to V3 schema (#9370) --- comfy_extras/nodes_string.py | 449 ++++++++++++++++++----------------- 1 file changed, 237 insertions(+), 212 deletions(-) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index b1a8ceef0..571d89f62 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,77 +1,91 @@ import re +from typing_extensions import override -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import ComfyExtension, io -class StringConcatenate(): + +class StringConcatenate(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "delimiter": (IO.STRING, {"multiline": False, "default": ""}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringConcatenate", + display_name="Concatenate", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.String.Input("delimiter", multiline=False, default=""), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, delimiter, **kwargs): - return delimiter.join((string_a, string_b)), - -class StringSubstring(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "start": (IO.INT, {}), - "end": (IO.INT, {}), - } - } + def execute(cls, string_a, string_b, delimiter): + return io.NodeOutput(delimiter.join((string_a, string_b))) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, start, end, **kwargs): - return string[start:end], - -class StringLength(): +class StringSubstring(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringSubstring", + display_name="Substring", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Int.Input("start"), + io.Int.Input("end"), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.INT,) - RETURN_NAMES = ("length",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, **kwargs): - length = len(string) - - return length, - -class CaseConverter(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]}) - } - } + def execute(cls, string, start, end): + return io.NodeOutput(string[start:end]) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, mode, **kwargs): +class StringLength(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringLength", + display_name="Length", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + ], + outputs=[ + io.Int.Output(display_name="length"), + ] + ) + + @classmethod + def execute(cls, string): + return io.NodeOutput(len(string)) + + +class CaseConverter(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CaseConverter", + display_name="Case Converter", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]), + ], + outputs=[ + io.String.Output(), + ] + ) + + @classmethod + def execute(cls, string, mode): if mode == "UPPERCASE": result = string.upper() elif mode == "lowercase": @@ -83,24 +97,27 @@ class CaseConverter(): else: result = string - return result, + return io.NodeOutput(result) -class StringTrim(): +class StringTrim(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringTrim", + display_name="Trim", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.Combo.Input("mode", options=["Both", "Left", "Right"]), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, mode, **kwargs): + @classmethod + def execute(cls, string, mode): if mode == "Both": result = string.strip() elif mode == "Left": @@ -110,70 +127,78 @@ class StringTrim(): else: result = string - return result, + return io.NodeOutput(result) -class StringReplace(): + +class StringReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "find": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringReplace", + display_name="Replace", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("find", multiline=True), + io.String.Input("replace", multiline=True), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, find, replace, **kwargs): - result = string.replace(find, replace) - return result, - - -class StringContains(): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "substring": (IO.STRING, {"multiline": True}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def execute(cls, string, find, replace): + return io.NodeOutput(string.replace(find, replace)) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("contains",) - FUNCTION = "execute" - CATEGORY = "utils/string" - def execute(self, string, substring, case_sensitive, **kwargs): +class StringContains(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StringContains", + display_name="Contains", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("substring", multiline=True), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(display_name="contains"), + ] + ) + + @classmethod + def execute(cls, string, substring, case_sensitive): if case_sensitive: contains = substring in string else: contains = substring.lower() in string.lower() - return contains, + return io.NodeOutput(contains) -class StringCompare(): +class StringCompare(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string_a": (IO.STRING, {"multiline": True}), - "string_b": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}), - "case_sensitive": (IO.BOOLEAN, {"default": True}) - } - } + def define_schema(cls): + return io.Schema( + node_id="StringCompare", + display_name="Compare", + category="utils/string", + inputs=[ + io.String.Input("string_a", multiline=True), + io.String.Input("string_b", multiline=True), + io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]), + io.Boolean.Input("case_sensitive", default=True), + ], + outputs=[ + io.Boolean.Output(), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string_a, string_b, mode, case_sensitive, **kwargs): + @classmethod + def execute(cls, string_a, string_b, mode, case_sensitive): if case_sensitive: a = string_a b = string_b @@ -182,31 +207,34 @@ class StringCompare(): b = string_b.lower() if mode == "Equal": - return a == b, + return io.NodeOutput(a == b) elif mode == "Starts With": - return a.startswith(b), + return io.NodeOutput(a.startswith(b)) elif mode == "Ends With": - return a.endswith(b), + return io.NodeOutput(a.endswith(b)) -class RegexMatch(): + +class RegexMatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexMatch", + display_name="Regex Match", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + ], + outputs=[ + io.Boolean.Output(display_name="matches"), + ] + ) - RETURN_TYPES = (IO.BOOLEAN,) - RETURN_NAMES = ("matches",) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall): flags = 0 if case_insensitive: @@ -223,29 +251,32 @@ class RegexMatch(): except re.error: result = False - return result, + return io.NodeOutput(result) -class RegexExtract(): +class RegexExtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}), - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False}), - "group_index": (IO.INT, {"default": 1, "min": 0, "max": 100}) - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexExtract", + display_name="Regex Extract", + category="utils/string", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]), + io.Boolean.Input("case_insensitive", default=True), + io.Boolean.Input("multiline", default=False), + io.Boolean.Input("dotall", default=False), + io.Int.Input("group_index", default=1, min=0, max=100), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index): join_delimiter = "\n" flags = 0 @@ -294,32 +325,33 @@ class RegexExtract(): except re.error: result = "" - return result, + return io.NodeOutput(result) -class RegexReplace(): - DESCRIPTION = "Find and replace text using regex patterns." +class RegexReplace(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "string": (IO.STRING, {"multiline": True}), - "regex_pattern": (IO.STRING, {"multiline": True}), - "replace": (IO.STRING, {"multiline": True}), - }, - "optional": { - "case_insensitive": (IO.BOOLEAN, {"default": True}), - "multiline": (IO.BOOLEAN, {"default": False}), - "dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}), - "count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="RegexReplace", + display_name="Regex Replace", + category="utils/string", + description="Find and replace text using regex patterns.", + inputs=[ + io.String.Input("string", multiline=True), + io.String.Input("regex_pattern", multiline=True), + io.String.Input("replace", multiline=True), + io.Boolean.Input("case_insensitive", default=True, optional=True), + io.Boolean.Input("multiline", default=False, optional=True), + io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."), + io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."), + ], + outputs=[ + io.String.Output(), + ] + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/string" - - def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs): + @classmethod + def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0): flags = 0 if case_insensitive: @@ -329,32 +361,25 @@ class RegexReplace(): if dotall: flags |= re.DOTALL result = re.sub(regex_pattern, replace, string, count=count, flags=flags) - return result, + return io.NodeOutput(result) -NODE_CLASS_MAPPINGS = { - "StringConcatenate": StringConcatenate, - "StringSubstring": StringSubstring, - "StringLength": StringLength, - "CaseConverter": CaseConverter, - "StringTrim": StringTrim, - "StringReplace": StringReplace, - "StringContains": StringContains, - "StringCompare": StringCompare, - "RegexMatch": RegexMatch, - "RegexExtract": RegexExtract, - "RegexReplace": RegexReplace, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "StringConcatenate": "Concatenate", - "StringSubstring": "Substring", - "StringLength": "Length", - "CaseConverter": "Case Converter", - "StringTrim": "Trim", - "StringReplace": "Replace", - "StringContains": "Contains", - "StringCompare": "Compare", - "RegexMatch": "Regex Match", - "RegexExtract": "Regex Extract", - "RegexReplace": "Regex Replace", -} +class StringExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StringConcatenate, + StringSubstring, + StringLength, + CaseConverter, + StringTrim, + StringReplace, + StringContains, + StringCompare, + RegexMatch, + RegexExtract, + RegexReplace, + ] + +async def comfy_entrypoint() -> StringExtension: + return StringExtension() From bab08f40d10c8737c3424e35bbff873bcb2333bd Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:05:36 +0300 Subject: [PATCH 65/71] v3 nodes (part a) (#9149) --- comfy_extras/nodes_ace.py | 80 +++++++----- comfy_extras/nodes_advanced_samplers.py | 88 +++++++------ comfy_extras/nodes_apg.py | 72 +++++++---- comfy_extras/nodes_attention_multiply.py | 154 ++++++++++++++--------- 4 files changed, 239 insertions(+), 155 deletions(-) diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index cbfec15a2..1409233c9 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -1,49 +1,63 @@ import torch +from typing_extensions import override + import comfy.model_management import node_helpers +from comfy_api.latest import ComfyExtension, io -class TextEncodeAceStepAudio: + +class TextEncodeAceStepAudio(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" + def define_schema(cls): + return io.Schema( + node_id="TextEncodeAceStepAudio", + category="conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("tags", multiline=True, dynamic_prompts=True), + io.String.Input("lyrics", multiline=True, dynamic_prompts=True), + io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Conditioning.Output()], + ) - CATEGORY = "conditioning" - - def encode(self, clip, tags, lyrics, lyrics_strength): + @classmethod + def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics) conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) - return (conditioning, ) + return io.NodeOutput(conditioning) -class EmptyAceStepLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyAceStepLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyAceStepLatentAudio", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds, batch_size): + def execute(cls, seconds, batch_size) -> io.NodeOutput: length = int(seconds * 44100 / 512 / 8) - latent = torch.zeros([batch_size, 8, 16, length], device=self.device) - return ({"samples": latent, "type": "audio"}, ) + latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) -NODE_CLASS_MAPPINGS = { - "TextEncodeAceStepAudio": TextEncodeAceStepAudio, - "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio, -} +class AceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeAceStepAudio, + EmptyAceStepLatentAudio, + ] + +async def comfy_entrypoint() -> AceExtension: + return AceExtension() diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 5fbb096fb..5532ffe6a 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -1,8 +1,13 @@ +import numpy as np +import torch +from tqdm.auto import trange +from typing_extensions import override + +import comfy.model_patcher import comfy.samplers import comfy.utils -import torch -import numpy as np -from tqdm.auto import trange +from comfy.k_diffusion.sampling import to_d +from comfy_api.latest import ComfyExtension, io @torch.no_grad() @@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable return x -class SamplerLCMUpscale: - upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] +class SamplerLCMUpscale(io.ComfyNode): + UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] @classmethod - def INPUT_TYPES(s): - return {"required": - {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}), - "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}), - "upscale_method": (s.upscale_methods,), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerLCMUpscale", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01), + io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1), + io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), + ], + outputs=[io.Sampler.Output()], + ) - FUNCTION = "get_sampler" - - def get_sampler(self, scale_ratio, scale_steps, upscale_method): + @classmethod + def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput: if scale_steps < 0: scale_steps = None sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method}) - return (sampler, ) + return io.NodeOutput(sampler) -from comfy.k_diffusion.sampling import to_d -import comfy.model_patcher @torch.no_grad() def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): @@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No return x -class SamplerEulerCFGpp: +class SamplerEulerCFGpp(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"version": (["regular", "alternative"],),} - } - RETURN_TYPES = ("SAMPLER",) - # CATEGORY = "sampling/custom_sampling/samplers" - CATEGORY = "_for_testing" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SamplerEulerCFGpp", + display_name="SamplerEulerCFG++", + category="_for_testing", # "sampling/custom_sampling/samplers" + inputs=[ + io.Combo.Input("version", options=["regular", "alternative"]), + ], + outputs=[io.Sampler.Output()], + is_experimental=True, + ) - FUNCTION = "get_sampler" - - def get_sampler(self, version): + @classmethod + def execute(cls, version) -> io.NodeOutput: if version == "alternative": sampler = comfy.samplers.KSAMPLER(sample_euler_pp) else: sampler = comfy.samplers.ksampler("euler_cfg_pp") - return (sampler, ) + return io.NodeOutput(sampler) -NODE_CLASS_MAPPINGS = { - "SamplerLCMUpscale": SamplerLCMUpscale, - "SamplerEulerCFGpp": SamplerEulerCFGpp, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerCFGpp": "SamplerEulerCFG++", -} +class AdvancedSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerLCMUpscale, + SamplerEulerCFGpp, + ] + +async def comfy_entrypoint() -> AdvancedSamplersExtension: + return AdvancedSamplersExtension() diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index 25b21b1b8..f27ae7da8 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -1,4 +1,8 @@ import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def project(v0, v1): v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3]) @@ -6,22 +10,45 @@ def project(v0, v1): v0_orthogonal = v0 - v0_parallel return v0_parallel, v0_orthogonal -class APG: +class APG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}), - "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}), - "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}), - } - } - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - CATEGORY = "sampling/custom_sampling" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="APG", + display_name="Adaptive Projected Guidance", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Float.Input( + "eta", + default=1.0, + min=-10.0, + max=10.0, + step=0.01, + tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.", + ), + io.Float.Input( + "norm_threshold", + default=5.0, + min=0.0, + max=50.0, + step=0.1, + tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.", + ), + io.Float.Input( + "momentum", + default=0.0, + min=-5.0, + max=1.0, + step=0.01, + tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.", + ), + ], + outputs=[io.Model.Output()], + ) - def patch(self, model, eta, norm_threshold, momentum): + @classmethod + def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput: running_avg = 0 prev_sigma = None @@ -65,12 +92,15 @@ class APG: m = model.clone() m.set_model_sampler_pre_cfg_function(pre_cfg_function) - return (m,) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "APG": APG, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "APG": "Adaptive Projected Guidance", -} +class ApgExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + APG, + ] + +async def comfy_entrypoint() -> ApgExtension: + return ApgExtension() diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index 4747eb395..c0e494c2a 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -1,3 +1,7 @@ +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def attention_multiply(attn, model, q, k, v, out): m = model.clone() @@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out): return m -class UNetSelfAttentionMultiply: +class UNetSelfAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetSelfAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn1", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class UNetCrossAttentionMultiply: + +class UNetCrossAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetCrossAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, q, k, v, out): + @classmethod + def execute(cls, model, q, k, v, out) -> io.NodeOutput: m = attention_multiply("attn2", model, q, k, v, out) - return (m, ) + return io.NodeOutput(m) -class CLIPAttentionMultiply: + +class CLIPAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip": ("CLIP",), - "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="CLIPAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Clip.Input("clip"), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Clip.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, clip, q, k, v, out): + @classmethod + def execute(cls, clip, q, k, v, out) -> io.NodeOutput: m = clip.clone() sd = m.patcher.model_state_dict() @@ -79,23 +97,28 @@ class CLIPAttentionMultiply: m.add_patches({key: (None,)}, 0.0, v) if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): m.add_patches({key: (None,)}, 0.0, out) - return (m, ) + return io.NodeOutput(m) -class UNetTemporalAttentionMultiply: + +class UNetTemporalAttentionMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="UNetTemporalAttentionMultiply", + category="_for_testing/attention_experiments", + inputs=[ + io.Model.Input("model"), + io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[io.Model.Output()], + is_experimental=True, + ) - CATEGORY = "_for_testing/attention_experiments" - - def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal): + @classmethod + def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput: m = model.clone() sd = model.model_state_dict() @@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply: m.add_patches({k: (None,)}, 0.0, cross_temporal) else: m.add_patches({k: (None,)}, 0.0, cross_structural) - return (m, ) + return io.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply, - "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply, - "CLIPAttentionMultiply": CLIPAttentionMultiply, - "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply, -} + +class AttentionMultiplyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + UNetSelfAttentionMultiply, + UNetCrossAttentionMultiply, + CLIPAttentionMultiply, + UNetTemporalAttentionMultiply, + ] + +async def comfy_entrypoint() -> AttentionMultiplyExtension: + return AttentionMultiplyExtension() From eb39019daae96128ee848d0b7837ede299518a7c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:06:13 +0300 Subject: [PATCH 66/71] [V3] convert Google Veo API node to the V3 schema (#9272) * convert Google Veo API node to the V3 schema * use own full io.Schema for Veo3VideoGenerationNode * fixed typo * use auth_kwargs instead of auth_token/comfy_api_key --- comfy_api_nodes/nodes_veo2.py | 332 +++++++++++++++++++--------------- 1 file changed, 190 insertions(+), 142 deletions(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e25dab2f5..251aecd42 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,17 +1,18 @@ -import io import logging import base64 import aiohttp import torch +from io import BytesIO from typing import Optional +from typing_extensions import override -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.input_impl.video_types import VideoFromFile from comfy_api_nodes.apis import ( VeoGenVidRequest, VeoGenVidResponse, VeoGenVidPollRequest, - VeoGenVidPollResponse + VeoGenVidPollResponse, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apinode_utils import ( downscale_image_tensor, - tensor_to_base64_string + tensor_to_base64_string, ) AVERAGE_DURATION_VIDEO_GEN = 32 @@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona return None -class VeoVideoGenerationNode(ComfyNodeABC): +class VeoVideoGenerationNode(comfy_io.ComfyNode): """ Generates videos from text prompts using Google's Veo API. @@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text description of the video", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="VeoVideoGenerationNode", + display_name="Google Veo 2 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 2 API", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", ), - "aspect_ratio": ( - IO.COMBO, - { - "options": ["16:9", "9:16"], - "default": "16:9", - "tooltip": "Aspect ratio of the output video", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Negative text prompt to guide what to avoid in the video", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, ), - "duration_seconds": ( - IO.INT, - { - "default": 5, - "min": 5, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "Duration of the output video in seconds", - }, + comfy_io.Int.Input( + "duration_seconds", + default=5, + min=5, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds", + optional=True, ), - "enhance_prompt": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Whether to enhance the prompt with AI assistance", - } + comfy_io.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, ), - "person_generation": ( - IO.COMBO, - { - "options": ["ALLOW", "BLOCK"], - "default": "ALLOW", - "tooltip": "Whether to allow generating people in the video", - }, + comfy_io.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFF, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "Seed for video generation (0 for random)", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, ), - "image": (IO.IMAGE, { - "default": None, - "tooltip": "Optional reference image to guide video generation", - }), - "model": ( - IO.COMBO, - { - "options": ["veo-2.0-generate-001"], - "default": "veo-2.0-generate-001", - "tooltip": "Veo 2 model to use for video generation", - }, + comfy_io.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + comfy_io.Combo.Input( + "model", + options=["veo-2.0-generate-001"], + default="veo-2.0-generate-001", + tooltip="Veo 2 model to use for video generation", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "generate_video" - CATEGORY = "api node/video/Veo" - DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" - API_NODE = True - - async def generate_video( - self, + @classmethod + async def execute( + cls, prompt, aspect_ratio="16:9", negative_prompt="", @@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC): image=None, model="veo-2.0-generate-001", generate_audio=False, - unique_id: Optional[str] = None, - **kwargs, ): # Prepare the instances for the request instances = [] @@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): if "veo-3.0" in model: parameters["generateAudio"] = generate_audio + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # Initial request to start video generation initial_operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -214,7 +209,7 @@ class VeoVideoGenerationNode(ComfyNodeABC): instances=instances, parameters=parameters ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) initial_response = await initial_operation.execute() @@ -248,10 +243,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): request=VeoGenVidPollRequest( operationName=operation_name ), - auth_kwargs=kwargs, + auth_kwargs=auth, poll_interval=5.0, result_url_extractor=get_video_url_from_response, - node_id=unique_id, + node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) @@ -304,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC): logging.info("Video generation completed successfully") # Convert video data to BytesIO object - video_io = io.BytesIO(video_data) + video_io = BytesIO(video_data) # Return VideoFromFile object - return (VideoFromFile(video_io),) + return comfy_io.NodeOutput(VideoFromFile(video_io)) class Veo3VideoGenerationNode(VeoVideoGenerationNode): @@ -323,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): """ @classmethod - def INPUT_TYPES(s): - parent_input = super().INPUT_TYPES() - - # Update model options for Veo 3 - parent_input["optional"]["model"] = ( - IO.COMBO, - { - "options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"], - "default": "veo-3.0-generate-001", - "tooltip": "Veo 3 model to use for video generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="Veo3VideoGenerationNode", + display_name="Google Veo 3 Video Generation", + category="api node/video/Veo", + description="Generates videos from text prompts using Google's Veo 3 API", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + optional=True, + ), + comfy_io.Int.Input( + "duration_seconds", + default=8, + min=8, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + optional=True, + ), + comfy_io.Boolean.Input( + "enhance_prompt", + default=True, + tooltip="Whether to enhance the prompt with AI assistance", + optional=True, + ), + comfy_io.Combo.Input( + "person_generation", + options=["ALLOW", "BLOCK"], + default="ALLOW", + tooltip="Whether to allow generating people in the video", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation (0 for random)", + optional=True, + ), + comfy_io.Image.Input( + "image", + tooltip="Optional reference image to guide video generation", + optional=True, + ), + comfy_io.Combo.Input( + "model", + options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"], + default="veo-3.0-generate-001", + tooltip="Veo 3 model to use for video generation", + optional=True, + ), + comfy_io.Boolean.Input( + "generate_audio", + default=False, + tooltip="Generate audio for the video. Supported by all Veo 3 models.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - # Add generateAudio parameter - parent_input["optional"]["generate_audio"] = ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Generate audio for the video. Supported by all Veo 3 models.", - } - ) - # Update duration constraints for Veo 3 (only 8 seconds supported) - parent_input["optional"]["duration_seconds"] = ( - IO.INT, - { - "default": 8, - "min": 8, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)", - }, - ) +class VeoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + VeoVideoGenerationNode, + Veo3VideoGenerationNode, + ] - return parent_input - - -# Register the nodes -NODE_CLASS_MAPPINGS = { - "VeoVideoGenerationNode": VeoVideoGenerationNode, - "Veo3VideoGenerationNode": Veo3VideoGenerationNode, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "VeoVideoGenerationNode": "Google Veo 2 Video Generation", - "Veo3VideoGenerationNode": "Google Veo 3 Video Generation", -} +async def comfy_entrypoint() -> VeoExtension: + return VeoExtension() From 7ed73d12d13c2c389e0469c46c2db635a7d74278 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 05:06:51 +0300 Subject: [PATCH 67/71] [V3] convert Ideogram API nodes to the V3 schema (#9278) * convert Ideogram API nodes to the V3 schema * use auth_kwargs instead of auth_token/comfy_api_key --- comfy_api_nodes/nodes_ideogram.py | 536 ++++++++++++++---------------- 1 file changed, 257 insertions(+), 279 deletions(-) diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index db24e6da4..d28895f3e 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -1,8 +1,8 @@ -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from inspect import cleandoc +from io import BytesIO +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io as comfy_io from PIL import Image import numpy as np -import io import torch from comfy_api_nodes.apis import ( IdeogramGenerateRequest, @@ -246,90 +246,81 @@ def display_image_urls_on_node(image_urls, node_id): PromptServer.instance.send_progress_text(urls_text, node_id) -class IdeogramV1(ComfyNodeABC): - """ - Generates images using the Ideogram V1 model. - """ - - def __init__(self): - pass +class IdeogramV1(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV1", + display_name="Ideogram V1", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V1 model.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + comfy_io.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -337,13 +328,15 @@ class IdeogramV1(ComfyNodeABC): seed=0, negative_prompt="", num_images=1, - unique_id=None, - **kwargs, ): # Determine the model based on turbo setting aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/ideogram/generate", @@ -364,7 +357,7 @@ class IdeogramV1(ComfyNodeABC): negative_prompt=negative_prompt if negative_prompt else None, ) ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await operation.execute() @@ -377,93 +370,85 @@ class IdeogramV1(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -class IdeogramV2(ComfyNodeABC): - """ - Generates images using the Ideogram V2 model. - """ - - def __init__(self): - pass +class IdeogramV2(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV2", + display_name="Ideogram V2", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V2 model.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation", ), - "turbo": ( - IO.BOOLEAN, - { - "default": False, - "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)", - } + comfy_io.Boolean.Input( + "turbo", + default=False, + tooltip="Whether to use turbo mode (faster generation, potentially lower quality)", ), - }, - "optional": { - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V1_V2_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V1_V2_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": list(V1_V1_RES_MAP.keys()), - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.", - }, + comfy_io.Combo.Input( + "resolution", + options=list(V1_V1_RES_MAP.keys()), + default="Auto", + tooltip="The resolution for image generation. " + "If not set to AUTO, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "style_type": ( - IO.COMBO, - { - "options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], - "default": "NONE", - "tooltip": "Style type for generation (V2 only)", - }, + comfy_io.Combo.Input( + "style_type", + options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"], + default="NONE", + tooltip="Style type for generation (V2 only)", + optional=True, ), - "negative_prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Description of what to exclude from the image", - }, + comfy_io.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Description of what to exclude from the image", + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), #"color_palette": ( # IO.STRING, @@ -473,22 +458,20 @@ class IdeogramV2(ComfyNodeABC): # "tooltip": "Color palette preset name or hex colors with weights", # }, #), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, turbo=False, aspect_ratio="1:1", @@ -499,8 +482,6 @@ class IdeogramV2(ComfyNodeABC): negative_prompt="", num_images=1, color_palette="", - unique_id=None, - **kwargs, ): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) resolution = V1_V1_RES_MAP.get(resolution, None) @@ -517,6 +498,10 @@ class IdeogramV2(ComfyNodeABC): else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/ideogram/generate", @@ -540,7 +525,7 @@ class IdeogramV2(ComfyNodeABC): color_palette=color_palette if color_palette else None, ) ), - auth_kwargs=kwargs, + auth_kwargs=auth, ) response = await operation.execute() @@ -553,108 +538,99 @@ class IdeogramV2(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -class IdeogramV3(ComfyNodeABC): - """ - Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask. - """ - def __init__(self): - pass +class IdeogramV3(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation or editing", - }, + def define_schema(cls): + return comfy_io.Schema( + node_id="IdeogramV3", + display_name="Ideogram V3", + category="api node/image/Ideogram", + description="Generates images using the Ideogram V3 model. " + "Supports both regular image generation from text prompts and image editing with mask.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt for the image generation or editing", ), - }, - "optional": { - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, + comfy_io.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, + comfy_io.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, ), - "aspect_ratio": ( - IO.COMBO, - { - "options": list(V3_RATIO_MAP.keys()), - "default": "1:1", - "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=list(V3_RATIO_MAP.keys()), + default="1:1", + tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.", + optional=True, ), - "resolution": ( - IO.COMBO, - { - "options": V3_RESOLUTIONS, - "default": "Auto", - "tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.", - }, + comfy_io.Combo.Input( + "resolution", + options=V3_RESOLUTIONS, + default="Auto", + tooltip="The resolution for image generation. " + "If not set to Auto, this overrides the aspect_ratio setting.", + optional=True, ), - "magic_prompt_option": ( - IO.COMBO, - { - "options": ["AUTO", "ON", "OFF"], - "default": "AUTO", - "tooltip": "Determine if MagicPrompt should be used in generation", - }, + comfy_io.Combo.Input( + "magic_prompt_option", + options=["AUTO", "ON", "OFF"], + default="AUTO", + tooltip="Determine if MagicPrompt should be used in generation", + optional=True, ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2147483647, - "step": 1, - "control_after_generate": True, - "display": "number", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "num_images": ( - IO.INT, - {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"}, + comfy_io.Int.Input( + "num_images", + default=1, + min=1, + max=8, + step=1, + display_mode=comfy_io.NumberDisplay.number, + optional=True, ), - "rendering_speed": ( - IO.COMBO, - { - "options": ["BALANCED", "TURBO", "QUALITY"], - "default": "BALANCED", - "tooltip": "Controls the trade-off between generation speed and quality", - }, + comfy_io.Combo.Input( + "rendering_speed", + options=["BALANCED", "TURBO", "QUALITY"], + default="BALANCED", + tooltip="Controls the trade-off between generation speed and quality", + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + ) - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/Ideogram" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, + @classmethod + async def execute( + cls, prompt, image=None, mask=None, @@ -664,9 +640,11 @@ class IdeogramV3(ComfyNodeABC): seed=0, num_images=1, rendering_speed="BALANCED", - unique_id=None, - **kwargs, ): + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } # Check if both image and mask are provided for editing mode if image is not None and mask is not None: # Edit mode @@ -686,7 +664,7 @@ class IdeogramV3(ComfyNodeABC): # Process image img_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(img_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) img_binary = img_byte_arr @@ -695,7 +673,7 @@ class IdeogramV3(ComfyNodeABC): # Process mask - white areas will be replaced mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) - mask_byte_arr = io.BytesIO() + mask_byte_arr = BytesIO() mask_img.save(mask_byte_arr, format="PNG") mask_byte_arr.seek(0) mask_binary = mask_byte_arr @@ -729,7 +707,7 @@ class IdeogramV3(ComfyNodeABC): "mask": mask_binary, }, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) elif image is not None or mask is not None: @@ -770,7 +748,7 @@ class IdeogramV3(ComfyNodeABC): response_model=IdeogramGenerateResponse, ), request=gen_request, - auth_kwargs=kwargs, + auth_kwargs=auth, ) # Execute the operation and process response @@ -784,18 +762,18 @@ class IdeogramV3(ComfyNodeABC): if not image_urls: raise Exception("No image URLs were generated in the response") - display_image_urls_on_node(image_urls, unique_id) - return (await download_and_process_images(image_urls),) + display_image_urls_on_node(image_urls, cls.hidden.unique_id) + return comfy_io.NodeOutput(await download_and_process_images(image_urls)) -NODE_CLASS_MAPPINGS = { - "IdeogramV1": IdeogramV1, - "IdeogramV2": IdeogramV2, - "IdeogramV3": IdeogramV3, -} +class IdeogramExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + IdeogramV1, + IdeogramV2, + IdeogramV3, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "IdeogramV1": "Ideogram V1", - "IdeogramV2": "Ideogram V2", - "IdeogramV3": "Ideogram V3", -} +async def comfy_entrypoint() -> IdeogramExtension: + return IdeogramExtension() From f7bd5e58dd03e799e02f6851b84b51e14ad0da7b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 Aug 2025 20:18:04 -0700 Subject: [PATCH 68/71] Make it easier to implement future qwen controlnets. (#9485) --- comfy/controlnet.py | 4 ++-- comfy/ldm/qwen_image/model.py | 16 +++++++++++++--- comfy/model_detection.py | 2 ++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 988acdb57..6cb69dcdf 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -236,11 +236,11 @@ class ControlNet(ControlBase): self.cond_hint = None compression_ratio = self.compression_ratio if self.vae is not None: - compression_ratio *= self.vae.downscale_ratio + compression_ratio *= self.vae.spacial_compression_encode() else: if self.latent_format is not None: raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.") - self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center") + self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center") self.cond_hint = self.preprocess_image(self.cond_hint) if self.vae is not None: loaded_models = comfy.model_management.loaded_models(only_currently_used=True) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 2503583cb..d0e39833a 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -293,6 +293,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), image_model=None, + final_layer=True, dtype=None, device=None, operations=None, @@ -300,6 +301,7 @@ class QwenImageTransformer2DModel(nn.Module): super().__init__() self.dtype = dtype self.patch_size = patch_size + self.in_channels = in_channels self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim @@ -329,9 +331,9 @@ class QwenImageTransformer2DModel(nn.Module): for _ in range(num_layers) ]) - self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) - self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) - self.gradient_checkpointing = False + if final_layer: + self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations) + self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) def process_img(self, x, index=0, h_offset=0, w_offset=0): bs, c, t, h, w = x.shape @@ -362,6 +364,7 @@ class QwenImageTransformer2DModel(nn.Module): guidance: torch.Tensor = None, ref_latents=None, transformer_options={}, + control=None, **kwargs ): timestep = timesteps @@ -443,6 +446,13 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = out["img"] encoder_hidden_states = out["txt"] + if control is not None: # Controlnet + control_i = control.get("input") + if i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states += add + hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2bec0541e..0caff53e0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -492,6 +492,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" + dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: From ff57793659702d502506047445f0972b10b6b9fe Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 Aug 2025 21:53:11 -0700 Subject: [PATCH 69/71] Support InstantX Qwen controlnet. (#9488) --- comfy/controlnet.py | 13 +++++ comfy/ldm/qwen_image/controlnet.py | 77 ++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 comfy/ldm/qwen_image/controlnet.py diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6cb69dcdf..e3dfedf55 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -36,6 +36,7 @@ import comfy.ldm.cascade.controlnet import comfy.cldm.mmdit import comfy.ldm.hydit.controlnet import comfy.ldm.flux.controlnet +import comfy.ldm.qwen_image.controlnet import comfy.cldm.dit_embedder from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -582,6 +583,15 @@ def load_controlnet_flux_instantx(sd, model_options={}): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control +def load_controlnet_qwen_instantx(sd, model_options={}): + model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) + control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_model = controlnet_load_state_dict(control_model, sd) + latent_format = comfy.latent_formats.Wan21() + extra_conds = [] + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + return control + def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -655,8 +665,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format else: return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet + elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data: + return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) + elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py new file mode 100644 index 000000000..92ac3cf0a --- /dev/null +++ b/comfy/ldm/qwen_image/controlnet.py @@ -0,0 +1,77 @@ +import torch +import math + +from .model import QwenImageTransformer2DModel + + +class QwenImageControlNetModel(QwenImageTransformer2DModel): + def __init__( + self, + extra_condition_channels=0, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + self.main_model_double = 60 + + # controlnet_blocks + self.controlnet_blocks = torch.nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype)) + self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype) + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + ref_latents=None, + hint=None, + transformer_options={}, + **kwargs + ): + timestep = timesteps + encoder_hidden_states = context + encoder_hidden_states_mask = attention_mask + + hidden_states, img_ids, orig_shape = self.process_img(x) + hint, _, _ = self.process_img(hint) + + txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + del ids, txt_ids, img_ids + + hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + self.time_text_embed(timestep, hidden_states) + if guidance is None + else self.time_text_embed(timestep, guidance, hidden_states) + ) + + repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks)) + + controlnet_block_samples = () + for i, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat + + return {"input": controlnet_block_samples[:self.main_model_double]} From 497d41fb500668635aa4a782549e9a0caa24375e Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 20:50:35 +0300 Subject: [PATCH 70/71] feat(api-nodes): change "OpenAI Chat" display name to "OpenAI ChatGPT" (#9443) --- comfy_api_nodes/nodes_openai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 674c9ede0..e3b81de75 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -998,7 +998,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "OpenAIDalle2": "OpenAI DALL·E 2", "OpenAIDalle3": "OpenAI DALL·E 3", "OpenAIGPTImage1": "OpenAI GPT Image 1", - "OpenAIChatNode": "OpenAI Chat", - "OpenAIInputFiles": "OpenAI Chat Input Files", - "OpenAIChatConfig": "OpenAI Chat Advanced Options", + "OpenAIChatNode": "OpenAI ChatGPT", + "OpenAIInputFiles": "OpenAI ChatGPT Input Files", + "OpenAIChatConfig": "OpenAI ChatGPT Advanced Options", } From 050c67323c33f6543309d4f09df706ec8c9a1389 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 Aug 2025 20:51:14 +0300 Subject: [PATCH 71/71] feat(api-nodes): add copy button to Gemini Chat node (#9440) --- comfy_api_nodes/nodes_gemini.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ba4167a50..78c402a7a 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -5,7 +5,10 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer from __future__ import annotations +import json +import time import os +import uuid from enum import Enum from typing import Optional, Literal @@ -350,7 +353,27 @@ class GeminiNode(ComfyNodeABC): # Get result output output_text = self.get_text_from_response(response) if unique_id and output_text: - PromptServer.instance.send_progress_text(output_text, node_id=unique_id) + # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. + render_spec = { + "node_id": unique_id, + "component": "ChatHistoryWidget", + "props": { + "history": json.dumps( + [ + { + "prompt": prompt, + "response": output_text, + "response_id": str(uuid.uuid4()), + "timestamp": time.time(), + } + ] + ), + }, + } + PromptServer.instance.send_sync( + "display_component", + render_spec, + ) return (output_text or "Empty response from Gemini model...",)