Merge upstream/master, keep local README.md

This commit is contained in:
GitHub Actions 2025-10-10 00:33:16 +00:00
commit ff7d9dd57d
26 changed files with 1143 additions and 865 deletions

View File

@ -123,16 +123,30 @@ def move_weight_functions(m, device):
return memory return memory
class LowVramPatch: class LowVramPatch:
def __init__(self, key, patches): def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key self.key = key
self.patches = patches self.patches = patches
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight): def __call__(self, weight):
intermediate_dtype = weight.dtype intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32 intermediate_dtype = torch.float32
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key)) out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
def get_key_weight(model, key): def get_key_weight(model, key):
set_func = None set_func = None
@ -657,13 +671,15 @@ class ModelPatcher:
if force_patch_weights: if force_patch_weights:
self.patch_weight_to_device(weight_key) self.patch_weight_to_device(weight_key)
else: else:
m.weight_function = [LowVramPatch(weight_key, self.patches)] _, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
patch_counter += 1 patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
if force_patch_weights: if force_patch_weights:
self.patch_weight_to_device(bias_key) self.patch_weight_to_device(bias_key)
else: else:
m.bias_function = [LowVramPatch(bias_key, self.patches)] _, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
patch_counter += 1 patch_counter += 1
cast_weight = True cast_weight = True
@ -825,10 +841,12 @@ class ModelPatcher:
module_mem += move_weight_functions(m, device_to) module_mem += move_weight_functions(m, device_to)
if lowvram_possible: if lowvram_possible:
if weight_key in self.patches: if weight_key in self.patches:
m.weight_function.append(LowVramPatch(weight_key, self.patches)) _, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1 patch_counter += 1
if bias_key in self.patches: if bias_key in self.patches:
m.bias_function.append(LowVramPatch(bias_key, self.patches)) _, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1 patch_counter += 1
cast_weight = True cast_weight = True

View File

@ -416,8 +416,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
else: else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update: if inplace_update:
self.weight.data.copy_(weight) self.weight.data.copy_(weight)
else: else:

View File

@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from comfy_api.latest._io import _IO as io #noqa: F401 from . import _io as io
from comfy_api.latest._ui import _UI as ui #noqa: F401 from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple from comfy_execution.progress import get_progress_state, PreviewImageTuple
@ -114,6 +114,8 @@ if TYPE_CHECKING:
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest) ComfyAPISync = create_sync_class(ComfyAPI_latest)
comfy_io = io # create the new alias for io
__all__ = [ __all__ = [
"ComfyAPI", "ComfyAPI",
"ComfyAPISync", "ComfyAPISync",
@ -121,4 +123,7 @@ __all__ = [
"InputImpl", "InputImpl",
"Types", "Types",
"ComfyExtension", "ComfyExtension",
"io",
"comfy_io",
"ui",
] ]

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Union from typing import Optional, Union, IO
import io import io
import av import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@ -23,7 +23,7 @@ class VideoInput(ABC):
@abstractmethod @abstractmethod
def save_to( def save_to(
self, self,
path: str, path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None metadata: Optional[dict] = None

View File

@ -1582,78 +1582,78 @@ class _UIOutput(ABC):
... ...
class _IO: __all__ = [
FolderType = FolderType "FolderType",
UploadType = UploadType "UploadType",
RemoteOptions = RemoteOptions "RemoteOptions",
NumberDisplay = NumberDisplay "NumberDisplay",
comfytype = staticmethod(comfytype) "comfytype",
Custom = staticmethod(Custom) "Custom",
Input = Input "Input",
WidgetInput = WidgetInput "WidgetInput",
Output = Output "Output",
ComfyTypeI = ComfyTypeI "ComfyTypeI",
ComfyTypeIO = ComfyTypeIO "ComfyTypeIO",
#---------------------------------
# Supported Types # Supported Types
Boolean = Boolean "Boolean",
Int = Int "Int",
Float = Float "Float",
String = String "String",
Combo = Combo "Combo",
MultiCombo = MultiCombo "MultiCombo",
Image = Image "Image",
WanCameraEmbedding = WanCameraEmbedding "WanCameraEmbedding",
Webcam = Webcam "Webcam",
Mask = Mask "Mask",
Latent = Latent "Latent",
Conditioning = Conditioning "Conditioning",
Sampler = Sampler "Sampler",
Sigmas = Sigmas "Sigmas",
Noise = Noise "Noise",
Guider = Guider "Guider",
Clip = Clip "Clip",
ControlNet = ControlNet "ControlNet",
Vae = Vae "Vae",
Model = Model "Model",
ClipVision = ClipVision "ClipVision",
ClipVisionOutput = ClipVisionOutput "ClipVisionOutput",
AudioEncoder = AudioEncoder "AudioEncoder",
AudioEncoderOutput = AudioEncoderOutput "AudioEncoderOutput",
StyleModel = StyleModel "StyleModel",
Gligen = Gligen "Gligen",
UpscaleModel = UpscaleModel "UpscaleModel",
Audio = Audio "Audio",
Video = Video "Video",
SVG = SVG "SVG",
LoraModel = LoraModel "LoraModel",
LossMap = LossMap "LossMap",
Voxel = Voxel "Voxel",
Mesh = Mesh "Mesh",
Hooks = Hooks "Hooks",
HookKeyframes = HookKeyframes "HookKeyframes",
TimestepsRange = TimestepsRange "TimestepsRange",
LatentOperation = LatentOperation "LatentOperation",
FlowControl = FlowControl "FlowControl",
Accumulation = Accumulation "Accumulation",
Load3DCamera = Load3DCamera "Load3DCamera",
Load3D = Load3D "Load3D",
Load3DAnimation = Load3DAnimation "Load3DAnimation",
Photomaker = Photomaker "Photomaker",
Point = Point "Point",
FaceAnalysis = FaceAnalysis "FaceAnalysis",
BBOX = BBOX "BBOX",
SEGS = SEGS "SEGS",
AnyType = AnyType "AnyType",
MultiType = MultiType "MultiType",
#--------------------------------- # Other classes
HiddenHolder = HiddenHolder "HiddenHolder",
Hidden = Hidden "Hidden",
NodeInfoV1 = NodeInfoV1 "NodeInfoV1",
NodeInfoV3 = NodeInfoV3 "NodeInfoV3",
Schema = Schema "Schema",
ComfyNode = ComfyNode "ComfyNode",
NodeOutput = NodeOutput "NodeOutput",
add_to_dict_v1 = staticmethod(add_to_dict_v1) "add_to_dict_v1",
add_to_dict_v3 = staticmethod(add_to_dict_v3) "add_to_dict_v3",
]

View File

@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
return {"text": (self.value,)} return {"text": (self.value,)}
class _UI: __all__ = [
SavedResult = SavedResult "SavedResult",
SavedImages = SavedImages "SavedImages",
SavedAudios = SavedAudios "SavedAudios",
ImageSaveHelper = ImageSaveHelper "ImageSaveHelper",
AudioSaveHelper = AudioSaveHelper "AudioSaveHelper",
PreviewImage = PreviewImage "PreviewImage",
PreviewMask = PreviewMask "PreviewMask",
PreviewAudio = PreviewAudio "PreviewAudio",
PreviewVideo = PreviewVideo "PreviewVideo",
PreviewUI3D = PreviewUI3D "PreviewUI3D",
PreviewText = PreviewText "PreviewText",
]

View File

@ -269,7 +269,7 @@ def tensor_to_bytesio(
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns: Returns:
Named BytesIO object containing the image data. Named BytesIO object containing the image data, with pointer set to the start of buffer.
""" """
if not mime_type: if not mime_type:
mime_type = "image/png" mime_type = "image/png"
@ -431,7 +431,7 @@ async def upload_video_to_comfyapi(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
) )
except Exception as e: except Exception as e:
logging.error(f"Error getting video duration: {e}") logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}" upload_mime_type = f"video/{container.value.lower()}"

View File

@ -98,7 +98,7 @@ import io
import os import os
import socket import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum from enum import Enum
import json import json
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
@ -175,7 +175,7 @@ class ApiClient:
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0, retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None, retry_status_codes: Optional[tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None, session: Optional[aiohttp.ClientSession] = None,
): ):
self.base_url = base_url self.base_url = base_url
@ -199,9 +199,9 @@ class ApiClient:
@staticmethod @staticmethod
def _create_json_payload_args( def _create_json_payload_args(
data: Optional[Dict[str, Any]] = None, data: Optional[dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[dict[str, str]] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"json": data, "json": data,
"headers": headers, "headers": headers,
@ -209,11 +209,11 @@ class ApiClient:
def _create_form_data_args( def _create_form_data_args(
self, self,
data: Dict[str, Any] | None, data: dict[str, Any] | None,
files: Dict[str, Any] | None, files: dict[str, Any] | None,
headers: Optional[Dict[str, str]] = None, headers: Optional[dict[str, str]] = None,
multipart_parser: Callable | None = None, multipart_parser: Callable | None = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if headers and "Content-Type" in headers: if headers and "Content-Type" in headers:
del headers["Content-Type"] del headers["Content-Type"]
@ -254,9 +254,9 @@ class ApiClient:
@staticmethod @staticmethod
def _create_urlencoded_form_data_args( def _create_urlencoded_form_data_args(
data: Dict[str, Any], data: dict[str, Any],
headers: Optional[Dict[str, str]] = None, headers: Optional[dict[str, str]] = None,
) -> Dict[str, Any]: ) -> dict[str, Any]:
headers = headers or {} headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded" headers["Content-Type"] = "application/x-www-form-urlencoded"
return { return {
@ -264,7 +264,7 @@ class ApiClient:
"headers": headers, "headers": headers,
} }
def get_headers(self) -> Dict[str, str]: def get_headers(self) -> dict[str, str]:
"""Get headers for API requests, including authentication if available""" """Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"} headers = {"Content-Type": "application/json", "Accept": "application/json"}
@ -275,7 +275,7 @@ class ApiClient:
return headers return headers
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]: async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
""" """
Check connectivity to determine if network issues are local or server-related. Check connectivity to determine if network issues are local or server-related.
@ -316,14 +316,14 @@ class ApiClient:
self, self,
method: str, method: str,
path: str, path: str,
params: Optional[Dict[str, Any]] = None, params: Optional[dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None, data: Optional[dict[str, Any]] = None,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[Dict[str, str]] = None, headers: Optional[dict[str, str]] = None,
content_type: str = "application/json", content_type: str = "application/json",
multipart_parser: Callable | None = None, multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]: ) -> dict[str, Any]:
""" """
Make an HTTP request to the API with automatic retries for transient errors. Make an HTTP request to the API with automatic retries for transient errors.
@ -359,10 +359,10 @@ class ApiClient:
if params: if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug(f"[DEBUG] Request Headers: {request_headers}") logging.debug("[DEBUG] Request Headers: %s", request_headers)
logging.debug(f"[DEBUG] Files: {files}") logging.debug("[DEBUG] Files: %s", files)
logging.debug(f"[DEBUG] Params: {params}") logging.debug("[DEBUG] Params: %s", params)
logging.debug(f"[DEBUG] Data: {data}") logging.debug("[DEBUG] Data: %s", data)
if content_type == "application/x-www-form-urlencoded": if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
@ -485,7 +485,7 @@ class ApiClient:
retry_delay: Initial delay between retries in seconds retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry retry_backoff_factor: Multiplier for the delay after each retry
""" """
headers: Dict[str, str] = {} headers: dict[str, str] = {}
skip_auto_headers: set[str] = set() skip_auto_headers: set[str] = set()
if content_type: if content_type:
headers["Content-Type"] = content_type headers["Content-Type"] = content_type
@ -558,7 +558,7 @@ class ApiClient:
*req_meta, *req_meta,
retry_count: int, retry_count: int,
response_content: dict | str = "", response_content: dict | str = "",
) -> Dict[str, Any]: ) -> dict[str, Any]:
status_code = exc.status status_code = exc.status
if status_code == 401: if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node." user_friendly = "Unauthorized: Please login first to use this node."
@ -592,9 +592,9 @@ class ApiClient:
error_message=f"HTTP Error {exc.status}", error_message=f"HTTP Error {exc.status}",
) )
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})") logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
if response_content: if response_content:
logging.debug(f"[DEBUG] Response content: {response_content}") logging.debug("[DEBUG] Response content: %s", response_content)
# Retry if eligible # Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries: if status_code in self.retry_status_codes and retry_count < self.max_retries:
@ -659,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
method: HttpMethod, method: HttpMethod,
request_model: Type[T], request_model: Type[T],
response_model: Type[R], response_model: Type[R],
query_params: Optional[Dict[str, Any]] = None, query_params: Optional[dict[str, Any]] = None,
): ):
"""Initialize an API endpoint definition. """Initialize an API endpoint definition.
@ -684,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
self, self,
endpoint: ApiEndpoint[T, R], endpoint: ApiEndpoint[T, R],
request: T, request: T,
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None, api_base: str | None = None,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None, comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None, auth_kwargs: Optional[dict[str, str]] = None,
timeout: float = 7200.0, timeout: float = 7200.0,
verify_ssl: bool = True, verify_ssl: bool = True,
content_type: str = "application/json", content_type: str = "application/json",
@ -729,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
) )
try: try:
request_dict: Optional[Dict[str, Any]] request_dict: Optional[dict[str, Any]]
if isinstance(self.request, EmptyRequest): if isinstance(self.request, EmptyRequest):
request_dict = None request_dict = None
else: else:
@ -738,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(v, Enum): if isinstance(v, Enum):
request_dict[k] = v.value request_dict[k] = v.value
logging.debug( logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
) logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
response_json = await client.request( response_json = await client.request(
self.endpoint.method.value, self.endpoint.method.value,
@ -757,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
logging.debug("=" * 50) logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:") logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)") logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}") logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
logging.debug("=" * 50) logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json) parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}") logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
return parsed_response return parsed_response
finally: finally:
if owns_client: if owns_client:
@ -784,14 +782,14 @@ class PollingOperation(Generic[T, R]):
poll_endpoint: ApiEndpoint[EmptyRequest, R], poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str], completed_statuses: list[str],
failed_statuses: list[str], failed_statuses: list[str],
status_extractor: Callable[[R], str], status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], float] | None = None, progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], str] | None = None, result_url_extractor: Callable[[R], Optional[str]] | None = None,
request: Optional[T] = None, request: Optional[T] = None,
api_base: str | None = None, api_base: str | None = None,
auth_token: Optional[str] = None, auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None, comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str, str]] = None, auth_kwargs: Optional[dict[str, str]] = None,
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call max_retries: int = 3, # Max retries per individual API call
@ -877,7 +875,7 @@ class PollingOperation(Generic[T, R]):
status = TaskStatus.PENDING status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1): for poll_count in range(1, self.max_poll_attempts + 1):
try: try:
logging.debug(f"[DEBUG] Polling attempt #{poll_count}") logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = ( request_dict = (
None if self.request is None else self.request.model_dump(exclude_none=True) None if self.request is None else self.request.model_dump(exclude_none=True)
@ -885,10 +883,13 @@ class PollingOperation(Generic[T, R]):
if poll_count == 1: if poll_count == 1:
logging.debug( logging.debug(
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}" "[DEBUG] Poll Request: %s %s",
self.poll_endpoint.method.value,
self.poll_endpoint.path,
) )
logging.debug( logging.debug(
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}" "[DEBUG] Poll Request Data: %s",
json.dumps(request_dict, indent=2) if request_dict else "None",
) )
# Query task status # Query task status
@ -903,7 +904,7 @@ class PollingOperation(Generic[T, R]):
# Check if task is complete # Check if task is complete
status = self._check_task_status(response_obj) status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}") logging.debug("[DEBUG] Task Status: %s", status)
# If progress extractor is provided, extract progress # If progress extractor is provided, extract progress
if self.progress_extractor: if self.progress_extractor:
@ -917,7 +918,7 @@ class PollingOperation(Generic[T, R]):
result_url = self.result_url_extractor(response_obj) result_url = self.result_url_extractor(response_obj)
if result_url: if result_url:
message = f"Result URL: {result_url}" message = f"Result URL: {result_url}"
logging.debug(f"[DEBUG] {message}") logging.debug("[DEBUG] %s", message)
self._display_text_on_node(message) self._display_text_on_node(message)
self.final_response = response_obj self.final_response = response_obj
if self.progress_extractor: if self.progress_extractor:
@ -925,7 +926,7 @@ class PollingOperation(Generic[T, R]):
return self.final_response return self.final_response
if status == TaskStatus.FAILED: if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}" message = f"Task failed: {json.dumps(resp)}"
logging.error(f"[DEBUG] {message}") logging.error("[DEBUG] %s", message)
raise Exception(message) raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...") logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending wait # Task pending wait
@ -939,7 +940,12 @@ class PollingOperation(Generic[T, R]):
raise Exception( raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}" f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e ) from e
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e)) logging.warning(
"Network error (%s/%s): %s",
consecutive_errors,
max_consecutive_errors,
str(e),
)
await asyncio.sleep(self.poll_interval) await asyncio.sleep(self.poll_interval)
except Exception as e: except Exception as e:
# For other errors, increment count and potentially abort # For other errors, increment count and potentially abort
@ -949,10 +955,13 @@ class PollingOperation(Generic[T, R]):
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e ) from e
logging.error(f"[DEBUG] Polling error: {str(e)}") logging.error("[DEBUG] Polling error: %s", str(e))
logging.warning( logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " "Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
f"Will retry in {self.poll_interval} seconds." poll_count,
self.max_poll_attempts,
str(e),
self.poll_interval,
) )
await asyncio.sleep(self.poll_interval) await asyncio.sleep(self.poll_interval)

View File

@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
    """Effect presets accepted by the Pika 'Pikaffects' video-effect endpoint.

    Member values are the exact strings the remote API expects, including
    the hyphenated spellings ("Cake-ify", "Eye-pop", "Ta-da") — do not
    normalize them.
    """

    Cake_ify = "Cake-ify"
    Crumble = "Crumble"
    Crush = "Crush"
    Decapitate = "Decapitate"
    Deflate = "Deflate"
    Dissolve = "Dissolve"
    Explode = "Explode"
    Eye_pop = "Eye-pop"
    Inflate = "Inflate"
    Levitate = "Levitate"
    Melt = "Melt"
    Peel = "Peel"
    Poke = "Poke"
    Squish = "Squish"
    Ta_da = "Ta-da"
    Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
    """Request body for the Pika 2.2 'Pikascenes' endpoint.

    Field names mirror the remote API's JSON keys (camelCase), so they are
    deliberately not snake_cased.  Endpoint semantics are inferred from the
    class name — confirm against the Pika API reference.
    """

    # Width / height; None lets the service choose.
    aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
    # Clip length in seconds (service default: 5).
    duration: Optional[int] = Field(5)
    # Required: how the uploaded ingredient images are combined.
    ingredientsMode: str = Field(...)
    negativePrompt: Optional[str] = Field(None)
    promptText: Optional[str] = Field(None)
    # Output resolution label (e.g. '1080p').
    resolution: Optional[str] = Field('1080p')
    # Generation seed; None means the service picks one.
    seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
    """Initial response from a Pika generate call: the id of the created video task."""

    video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
    """Request body for the Pika 2.2 image-to-video endpoint.

    Field names mirror the remote API's JSON keys (camelCase).  Defaults are
    declared via ``Field(...)`` for consistency with the sibling request
    models in this module; ``Field(default)`` is equivalent to a plain
    default in pydantic, so behavior is unchanged.
    """

    # Clip length in seconds (service default: 5).
    duration: Optional[int] = Field(5)
    negativePrompt: Optional[str] = Field(None)
    promptText: Optional[str] = Field(None)
    # Output resolution label (e.g. '1080p').
    resolution: Optional[str] = Field('1080p')
    # Generation seed; None means the service picks one.
    seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
    """Request body for the Pika 2.2 keyframe ('Pikaframes') endpoint.

    Field names mirror the remote API's JSON keys (camelCase).  The
    ``resolution`` default is declared via ``Field(...)`` for consistency
    with the sibling request models; this is behavior-identical in pydantic.
    """

    # Clip length in seconds; this endpoint constrains it to 5..10.
    duration: Optional[int] = Field(None, ge=5, le=10)
    negativePrompt: Optional[str] = Field(None)
    # Prompt text is required for this endpoint.
    promptText: str = Field(...)
    # Output resolution label (e.g. '1080p').
    resolution: Optional[str] = Field('1080p')
    # Generation seed; None means the service picks one.
    seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
    """Request body for the Pika 2.2 text-to-video endpoint.

    Field names mirror the remote API's JSON keys (camelCase).  Defaults are
    declared via ``Field(...)`` for consistency with the sibling request
    models; ``Field(default)`` is equivalent to a plain default in pydantic.
    """

    # Width / height; the default 1.7777... is 16:9, clamped to [0.4, 2.5].
    aspectRatio: Optional[float] = Field(
        1.7777777777777777,
        description='Aspect ratio (width / height)',
        ge=0.4,
        le=2.5,
    )
    # Clip length in seconds (service default: 5).
    duration: Optional[int] = Field(5)
    negativePrompt: Optional[str] = Field(None)
    # Prompt text is required for this endpoint.
    promptText: str = Field(...)
    # Output resolution label (e.g. '1080p').
    resolution: Optional[str] = Field('1080p')
    # Generation seed; None means the service picks one.
    seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
    """Request body for the Pika 'Pikadditions' endpoint.

    Endpoint semantics are inferred from the class name — confirm against
    the Pika API reference.  The video/image payloads are presumably sent
    as separate multipart parts, not in this JSON body.
    """

    negativePrompt: Optional[str] = Field(None)
    promptText: Optional[str] = Field(None)
    # Generation seed; None means the service picks one.
    seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
    """Request body for the Pika 'Pikaffects' video-effect endpoint."""

    negativePrompt: Optional[str] = Field(None)
    # NOTE(review): presumably one of the Pikaffect enum values, but typed
    # as a plain str, so invalid names are only rejected server-side —
    # confirm whether this should be Optional[Pikaffect].
    pikaffect: Optional[str] = None
    promptText: Optional[str] = Field(None)
    # Generation seed; None means the service picks one.
    seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
    """Request body for the Pika 'Pikaswaps' endpoint.

    Endpoint semantics are inferred from the class name — confirm against
    the Pika API reference.
    """

    negativePrompt: Optional[str] = Field(None)
    promptText: Optional[str] = Field(None)
    # Generation seed; None means the service picks one.
    seed: Optional[int] = Field(None)
    # Region of interest to modify; exact string format is API-defined —
    # confirm against the Pika API reference.
    modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
    """Task lifecycle states reported by Pika status polling."""

    queued = "queued"
    started = "started"
    finished = "finished"
    failed = "failed"
class PikaVideoResponse(BaseModel):
    """Polling response describing the state of a Pika video task."""

    # Task identifier being polled.
    id: str = Field(...)
    # Progress value (presumably a percentage); None until reported.
    progress: Optional[int] = Field(None)
    # Current lifecycle state of the task.
    status: PikaStatusEnum
    # Video URL; None until the service provides one.
    url: Optional[str] = Field(None)

View File

@ -21,7 +21,7 @@ def get_log_directory():
try: try:
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
except Exception as e: except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}") logger.error("Error creating API log directory %s: %s", log_dir, str(e))
# Fallback to base temp directory if sub-directory creation fails # Fallback to base temp directory if sub-directory creation fails
return base_temp_dir return base_temp_dir
return log_dir return log_dir
@ -122,9 +122,9 @@ def log_request_response(
try: try:
with open(filepath, "w", encoding="utf-8") as f: with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content)) f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}") logger.debug("API log saved to: %s", filepath)
except Exception as e: except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}") logger.error("Error writing API log to %s: %s", filepath, str(e))
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -296,7 +296,7 @@ def validate_video_result_response(response) -> None:
"""Validates that the Kling task result contains a video.""" """Validates that the Kling task result contains a video."""
if not is_valid_video_response(response): if not is_valid_video_response(response):
error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response." error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response."
logging.error(f"Error: {error_msg}.\nResponse: {response}") logging.error("Error: %s.\nResponse: %s", error_msg, response)
raise Exception(error_msg) raise Exception(error_msg)
@ -304,7 +304,7 @@ def validate_image_result_response(response) -> None:
"""Validates that the Kling task result contains an image.""" """Validates that the Kling task result contains an image."""
if not is_valid_image_response(response): if not is_valid_image_response(response):
error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response." error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response."
logging.error(f"Error: {error_msg}.\nResponse: {response}") logging.error("Error: %s.\nResponse: %s", error_msg, response)
raise Exception(error_msg) raise Exception(error_msg)

View File

@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
raise Exception( raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}" f"No video was found in the response. Full response: {file_result.model_dump()}"
) )
logging.info(f"Generated video URL: {file_url}") logging.info("Generated video URL: %s", file_url)
if cls.hidden.unique_id: if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"): if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"

View File

@ -237,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
audio_stream = None audio_stream = None
for stream in input_container.streams: for stream in input_container.streams:
logging.info(f"Found stream: type={stream.type}, class={type(stream)}") logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream): if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters # Create output video stream with same parameters
video_stream = output_container.add_stream( video_stream = output_container.add_stream(
@ -247,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
video_stream.height = stream.height video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p" video_stream.pix_fmt = "yuv420p"
logging.info( logging.info(
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps" "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
) )
elif isinstance(stream, av.AudioStream): elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters # Create output audio stream with same parameters
@ -256,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
) )
audio_stream.sample_rate = stream.sample_rate audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout audio_stream.layout = stream.layout
logging.info( logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
)
# Calculate target frame count that's divisible by 16 # Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate fps = input_container.streams.video[0].average_rate
@ -288,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in video_stream.encode(): for packet in video_stream.encode():
output_container.mux(packet) output_container.mux(packet)
logging.info( logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
f"Encoded {frame_count} video frames (target: {target_frames})"
)
# Decode and re-encode audio frames # Decode and re-encode audio frames
if audio_stream: if audio_stream:
@ -308,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
for packet in audio_stream.encode(): for packet in audio_stream.encode():
output_container.mux(packet) output_container.mux(packet)
logging.info(f"Encoded {audio_frame_count} audio frames") logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers # Close containers
output_container.close() output_container.close()

View File

@ -8,30 +8,17 @@ from __future__ import annotations
from io import BytesIO from io import BytesIO
import logging import logging
from typing import Optional, TypeVar from typing import Optional, TypeVar
from enum import Enum
import numpy as np
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io from comfy_api.latest import ComfyExtension, comfy_io
from comfy_api.input_impl import VideoFromFile
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import ( from comfy_api_nodes.apinode_utils import (
download_url_to_video_output, download_url_to_video_output,
tensor_to_bytesio, tensor_to_bytesio,
) )
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import pika_defs
PikaBodyGenerate22C2vGenerate22PikascenesPost,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaGenerateResponse,
PikaVideoResponse,
)
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apis.client import (
ApiEndpoint, ApiEndpoint,
EmptyRequest, EmptyRequest,
@ -55,116 +42,36 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos" PATH_VIDEO_GET = "/proxy/pika/videos"
class PikaDurationEnum(int, Enum): async def execute_task(
integer_5 = 5 initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
integer_10 = 10
class PikaResolutionEnum(str, Enum):
field_1080p = "1080p"
field_720p = "720p"
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaApiError(Exception):
"""Exception for Pika API errors."""
pass
def is_valid_video_response(response: PikaVideoResponse) -> bool:
"""Check if the video response is valid."""
return hasattr(response, "url") and response.url is not None
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
"""Check if the initial response is valid."""
return hasattr(response, "video_id") and response.video_id is not None
async def poll_for_task_status(
task_id: str,
auth_kwargs: Optional[dict[str, str]] = None, auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None, node_id: Optional[str] = None,
) -> PikaGenerateResponse: ) -> comfy_io.NodeOutput:
polling_operation = PollingOperation( task_id = (await initial_operation.execute()).video_id
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
poll_endpoint=ApiEndpoint( poll_endpoint=ApiEndpoint(
path=f"{PATH_VIDEO_GET}/{task_id}", path=f"{PATH_VIDEO_GET}/{task_id}",
method=HttpMethod.GET, method=HttpMethod.GET,
request_model=EmptyRequest, request_model=EmptyRequest,
response_model=PikaVideoResponse, response_model=pika_defs.PikaVideoResponse,
), ),
completed_statuses=[ completed_statuses=["finished"],
"finished",
],
failed_statuses=["failed", "cancelled"], failed_statuses=["failed", "cancelled"],
status_extractor=lambda response: ( status_extractor=lambda response: (response.status.value if response.status else None),
response.status.value if response.status else None progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
),
progress_extractor=lambda response: (
response.progress if hasattr(response, "progress") else None
),
auth_kwargs=auth_kwargs, auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: ( result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
response.url if hasattr(response, "url") else None
),
node_id=node_id, node_id=node_id,
estimated_duration=60 estimated_duration=60,
) max_poll_attempts=240,
return await polling_operation.execute() ).execute()
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
async def execute_task(
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
"""Executes the initial operation then polls for the task status until it is completed.
Args:
initial_operation: The initial operation to execute.
auth_kwargs: The authentication token(s) to use for the API call.
Returns:
A tuple containing the video file as a VIDEO output.
"""
initial_response = await initial_operation.execute()
if not is_valid_initial_response(initial_response):
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
logging.error(error_msg) logging.error(error_msg)
raise PikaApiError(error_msg) raise Exception(error_msg)
video_url = final_response.url
task_id = initial_response.video_id
final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
if not is_valid_video_response(final_response):
error_msg = (
f"Pika task {task_id} succeeded but no video data found in response."
)
logging.error(error_msg)
raise PikaApiError(error_msg)
video_url = str(final_response.url)
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
return (await download_url_to_video_output(video_url),)
def get_base_inputs_types() -> list[comfy_io.Input]: def get_base_inputs_types() -> list[comfy_io.Input]:
@ -173,16 +80,12 @@ def get_base_inputs_types() -> list[comfy_io.Input]:
comfy_io.String.Input("prompt_text", multiline=True), comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True), comfy_io.String.Input("negative_prompt", multiline=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True), comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
comfy_io.Combo.Input( comfy_io.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
"resolution", options=PikaResolutionEnum, default=PikaResolutionEnum.field_1080p comfy_io.Combo.Input("duration", options=[5, 10], default=5),
),
comfy_io.Combo.Input(
"duration", options=PikaDurationEnum, default=PikaDurationEnum.integer_5
),
] ]
class PikaImageToVideoV2_2(comfy_io.ComfyNode): class PikaImageToVideo(comfy_io.ComfyNode):
"""Pika 2.2 Image to Video Node.""" """Pika 2.2 Image to Video Node."""
@classmethod @classmethod
@ -215,14 +118,9 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
resolution: str, resolution: str,
duration: int, duration: int,
) -> comfy_io.NodeOutput: ) -> comfy_io.NodeOutput:
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image) image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")} pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
# Prepare non-file data
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
@ -237,8 +135,8 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO, path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=PikaBodyGenerate22I2vGenerate22I2vPost, request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=PikaGenerateResponse, response_model=pika_defs.PikaGenerateResponse,
), ),
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
@ -248,7 +146,7 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode): class PikaTextToVideoNode(comfy_io.ComfyNode):
"""Pika Text2Video v2.2 Node.""" """Pika Text2Video v2.2 Node."""
@classmethod @classmethod
@ -296,10 +194,10 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO, path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=PikaBodyGenerate22T2vGenerate22T2vPost, request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=PikaGenerateResponse, response_model=pika_defs.PikaGenerateResponse,
), ),
request=PikaBodyGenerate22T2vGenerate22T2vPost( request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
@ -313,7 +211,7 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaScenesV2_2(comfy_io.ComfyNode): class PikaScenes(comfy_io.ComfyNode):
"""PikaScenes v2.2 Node.""" """PikaScenes v2.2 Node."""
@classmethod @classmethod
@ -389,7 +287,6 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
image_ingredient_4: Optional[torch.Tensor] = None, image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None, image_ingredient_5: Optional[torch.Tensor] = None,
) -> comfy_io.NodeOutput: ) -> comfy_io.NodeOutput:
# Convert all passed images to BytesIO
all_image_bytes_io = [] all_image_bytes_io = []
for image in [ for image in [
image_ingredient_1, image_ingredient_1,
@ -399,16 +296,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
image_ingredient_5, image_ingredient_5,
]: ]:
if image is not None: if image is not None:
image_bytes_io = tensor_to_bytesio(image) all_image_bytes_io.append(tensor_to_bytesio(image))
image_bytes_io.seek(0)
all_image_bytes_io.append(image_bytes_io)
pika_files = [ pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png")) ("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io) for i, image_bytes_io in enumerate(all_image_bytes_io)
] ]
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost( pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode, ingredientsMode=ingredients_mode,
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
@ -425,8 +320,8 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=PATH_PIKASCENES, path=PATH_PIKASCENES,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost, request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=PikaGenerateResponse, response_model=pika_defs.PikaGenerateResponse,
), ),
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
@ -477,22 +372,16 @@ class PikAdditionsNode(comfy_io.ComfyNode):
negative_prompt: str, negative_prompt: str,
seed: int, seed: int,
) -> comfy_io.NodeOutput: ) -> comfy_io.NodeOutput:
# Convert video to BytesIO
video_bytes_io = BytesIO() video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0) video_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image) image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = { pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"), "video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"), "image": ("image.png", image_bytes_io, "image/png"),
} }
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
# Prepare non-file data
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
@ -505,8 +394,8 @@ class PikAdditionsNode(comfy_io.ComfyNode):
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS, path=PATH_PIKADDITIONS,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost, request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=PikaGenerateResponse, response_model=pika_defs.PikaGenerateResponse,
), ),
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
@ -529,11 +418,25 @@ class PikaSwapsNode(comfy_io.ComfyNode):
category="api node/video/Pika", category="api node/video/Pika",
inputs=[ inputs=[
comfy_io.Video.Input("video", tooltip="The video to swap an object in."), comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."), comfy_io.Image.Input(
comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"), "image",
comfy_io.String.Input("prompt_text", multiline=True), tooltip="The image used to replace the masked object in the video.",
comfy_io.String.Input("negative_prompt", multiline=True), optional=True,
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True), ),
comfy_io.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
comfy_io.String.Input("prompt_text", multiline=True, optional=True),
comfy_io.String.Input("negative_prompt", multiline=True, optional=True),
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
comfy_io.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
], ],
outputs=[comfy_io.Video.Output()], outputs=[comfy_io.Video.Output()],
hidden=[ hidden=[
@ -548,41 +451,29 @@ class PikaSwapsNode(comfy_io.ComfyNode):
async def execute( async def execute(
cls, cls,
video: VideoInput, video: VideoInput,
image: torch.Tensor, image: Optional[torch.Tensor] = None,
mask: torch.Tensor, mask: Optional[torch.Tensor] = None,
prompt_text: str, prompt_text: str = "",
negative_prompt: str, negative_prompt: str = "",
seed: int, seed: int = 0,
region_to_modify: str = "",
) -> comfy_io.NodeOutput: ) -> comfy_io.NodeOutput:
# Convert video to BytesIO
video_bytes_io = BytesIO() video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264) video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0) video_bytes_io.seek(0)
# Convert mask to binary mask with three channels
mask = torch.round(mask)
mask = mask.repeat(1, 3, 1, 1)
# Convert 3-channel binary mask to BytesIO
mask_bytes_io = BytesIO()
mask_bytes_io.write(mask.numpy().astype(np.uint8))
mask_bytes_io.seek(0)
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
image_bytes_io.seek(0)
pika_files = { pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"), "video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
} }
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
# Prepare non-file data pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
) )
auth = { auth = {
"auth_token": cls.hidden.auth_token_comfy_org, "auth_token": cls.hidden.auth_token_comfy_org,
@ -590,10 +481,10 @@ class PikaSwapsNode(comfy_io.ComfyNode):
} }
initial_operation = SynchronousOperation( initial_operation = SynchronousOperation(
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS, path=PATH_PIKASWAPS,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost, request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=PikaGenerateResponse, response_model=pika_defs.PikaGenerateResponse,
), ),
request=pika_request_data, request=pika_request_data,
files=pika_files, files=pika_files,
@ -616,7 +507,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
inputs=[ inputs=[
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."), comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
comfy_io.Combo.Input( comfy_io.Combo.Input(
"pikaffect", options=Pikaffect, default="Cake-ify" "pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
), ),
comfy_io.String.Input("prompt_text", multiline=True), comfy_io.String.Input("prompt_text", multiline=True),
comfy_io.String.Input("negative_prompt", multiline=True), comfy_io.String.Input("negative_prompt", multiline=True),
@ -648,10 +539,10 @@ class PikaffectsNode(comfy_io.ComfyNode):
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS, path=PATH_PIKAFFECTS,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost, request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=PikaGenerateResponse, response_model=pika_defs.PikaGenerateResponse,
), ),
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost( request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect, pikaffect=pikaffect,
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
@ -664,7 +555,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode): class PikaStartEndFrameNode(comfy_io.ComfyNode):
"""PikaFrames v2.2 Node.""" """PikaFrames v2.2 Node."""
@classmethod @classmethod
@ -711,10 +602,10 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
endpoint=ApiEndpoint( endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES, path=PATH_PIKAFRAMES,
method=HttpMethod.POST, method=HttpMethod.POST,
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost, request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=PikaGenerateResponse, response_model=pika_defs.PikaGenerateResponse,
), ),
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost( request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
@ -732,13 +623,13 @@ class PikaApiNodesExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [ return [
PikaImageToVideoV2_2, PikaImageToVideo,
PikaTextToVideoNodeV2_2, PikaTextToVideoNode,
PikaScenesV2_2, PikaScenes,
PikAdditionsNode, PikAdditionsNode,
PikaSwapsNode, PikaSwapsNode,
PikaffectsNode, PikaffectsNode,
PikaStartEndFrameNode2_2, PikaStartEndFrameNode,
] ]

View File

@ -172,16 +172,16 @@ async def create_generate_task(
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!") logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key subscription_key = response.jobs.subscription_key
task_uuid = response.uuid task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
return task_uuid, subscription_key return task_uuid, subscription_key
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
all_done = all(job.status == JobStatus.Done for job in response.jobs) all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs] status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}") logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
if any(job.status == JobStatus.Failed for job in response.jobs): if any(job.status == JobStatus.Failed for job in response.jobs):
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.") logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.") raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
if all_done: if all_done:
return "DONE" return "DONE"
@ -235,7 +235,7 @@ async def download_files(url_list, task_uuid):
file_path = os.path.join(save_path, file_name) file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"): if file_path.endswith(".glb"):
model_file_path = file_path model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5 max_retries = 5
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
@ -246,7 +246,7 @@ async def download_files(url_list, task_uuid):
f.write(chunk) f.write(chunk)
break break
except Exception as e: except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
if attempt < max_retries - 1: if attempt < max_retries - 1:
logging.info("Retrying...") logging.info("Retrying...")
await asyncio.sleep(2) await asyncio.sleep(2)

View File

@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
initial_response = await initial_operation.execute() initial_response = await initial_operation.execute()
operation_name = initial_response.name operation_name = initial_response.name
logging.info(f"Veo generation started with operation name: {operation_name}") logging.info("Veo generation started with operation name: %s", operation_name)
# Define status extractor function # Define status extractor function
def status_extractor(response): def status_extractor(response):

View File

@ -1,6 +1,9 @@
import torch import torch
import comfy.utils import comfy.utils
from enum import Enum from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def resize_mask(mask, shape): def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
return out_image, out_alpha return out_image, out_alpha
class PorterDuffImageComposite: class PorterDuffImageComposite(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="PorterDuffImageComposite",
"source": ("IMAGE",), display_name="Porter-Duff Image Composite",
"source_alpha": ("MASK",), category="mask/compositing",
"destination": ("IMAGE",), inputs=[
"destination_alpha": ("MASK",), io.Image.Input("source"),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), io.Mask.Input("source_alpha"),
}, io.Image.Input("destination"),
} io.Mask.Input("destination_alpha"),
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
RETURN_TYPES = ("IMAGE", "MASK") @classmethod
FUNCTION = "composite" def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
CATEGORY = "mask/compositing"
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
out_images = [] out_images = []
out_alphas = [] out_alphas = []
@ -150,45 +157,48 @@ class PorterDuffImageComposite:
out_images.append(out_image) out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2)) out_alphas.append(out_alpha.squeeze(2))
result = (torch.stack(out_images), torch.stack(out_alphas)) return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
return result
class SplitImageWithAlpha: class SplitImageWithAlpha(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="SplitImageWithAlpha",
"image": ("IMAGE",), display_name="Split Image with Alpha",
} category="mask/compositing",
} inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)
CATEGORY = "mask/compositing" @classmethod
RETURN_TYPES = ("IMAGE", "MASK") def execute(cls, image: torch.Tensor) -> io.NodeOutput:
FUNCTION = "split_image_with_alpha"
def split_image_with_alpha(self, image: torch.Tensor):
out_images = [i[:,:,:3] for i in image] out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
class JoinImageWithAlpha: class JoinImageWithAlpha(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return io.Schema(
"required": { node_id="JoinImageWithAlpha",
"image": ("IMAGE",), display_name="Join Image with Alpha",
"alpha": ("MASK",), category="mask/compositing",
} inputs=[
} io.Image.Input("image"),
io.Mask.Input("alpha"),
],
outputs=[io.Image.Output()],
)
CATEGORY = "mask/compositing" @classmethod
RETURN_TYPES = ("IMAGE",) def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
FUNCTION = "join_image_with_alpha"
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
batch_size = min(len(image), len(alpha)) batch_size = min(len(image), len(alpha))
out_images = [] out_images = []
@ -196,19 +206,18 @@ class JoinImageWithAlpha:
for i in range(batch_size): for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
result = (torch.stack(out_images),) return io.NodeOutput(torch.stack(out_images))
return result
NODE_CLASS_MAPPINGS = { class CompositingExtension(ComfyExtension):
"PorterDuffImageComposite": PorterDuffImageComposite, @override
"SplitImageWithAlpha": SplitImageWithAlpha, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"JoinImageWithAlpha": JoinImageWithAlpha, return [
} PorterDuffImageComposite,
SplitImageWithAlpha,
JoinImageWithAlpha,
]
NODE_DISPLAY_NAME_MAPPINGS = { async def comfy_entrypoint() -> CompositingExtension:
"PorterDuffImageComposite": "Porter-Duff Image Composite", return CompositingExtension()
"SplitImageWithAlpha": "Split Image with Alpha",
"JoinImageWithAlpha": "Join Image with Alpha",
}

View File

@ -1,60 +1,80 @@
import node_helpers import node_helpers
import comfy.utils import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeFlux:
class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="CLIPTextEncodeFlux",
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning/flux",
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), io.Clip.Input("clip"),
}} io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning/flux" @classmethod
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
def encode(self, clip, clip_l, t5xxl, guidance):
tokens = clip.tokenize(clip_l) tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
class FluxGuidance: encode = execute # TODO: remove
class FluxGuidance(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"conditioning": ("CONDITIONING", ), node_id="FluxGuidance",
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), category="advanced/conditioning/flux",
}} inputs=[
io.Conditioning.Input("conditioning"),
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "append" def execute(cls, conditioning, guidance) -> io.NodeOutput:
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, guidance):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance}) c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, ) return io.NodeOutput(c)
append = execute # TODO: remove
class FluxDisableGuidance: class FluxDisableGuidance(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"conditioning": ("CONDITIONING", ), node_id="FluxDisableGuidance",
}} category="advanced/conditioning/flux",
description="This node completely disables the guidance embed on Flux and Flux like models",
inputs=[
io.Conditioning.Input("conditioning"),
],
outputs=[
io.Conditioning.Output(),
],
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "append" def execute(cls, conditioning) -> io.NodeOutput:
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
def append(self, conditioning):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None}) c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
return (c, ) return io.NodeOutput(c)
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [ PREFERED_KONTEXT_RESOLUTIONS = [
@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
] ]
class FluxKontextImageScale: class FluxKontextImageScale(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"image": ("IMAGE", ), return io.Schema(
}, node_id="FluxKontextImageScale",
} category="advanced/conditioning/flux",
description="This node resizes the image to one that is more optimal for flux kontext.",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
RETURN_TYPES = ("IMAGE",) @classmethod
FUNCTION = "scale" def execute(cls, image) -> io.NodeOutput:
CATEGORY = "advanced/conditioning/flux"
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
def scale(self, image):
width = image.shape[2] width = image.shape[2]
height = image.shape[1] height = image.shape[1]
aspect_ratio = width / height aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return (image, ) return io.NodeOutput(image)
scale = execute # TODO: remove
class FluxKontextMultiReferenceLatentMethod: class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"conditioning": ("CONDITIONING", ), node_id="FluxKontextMultiReferenceLatentMethod",
"reference_latents_method": (("offset", "index", "uxo/uno"), ), category="advanced/conditioning/flux",
}} inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
),
],
outputs=[
io.Conditioning.Output(),
],
is_experimental=True,
)
RETURN_TYPES = ("CONDITIONING",) @classmethod
FUNCTION = "append" def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
if "uxo" in reference_latents_method or "uso" in reference_latents_method: if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo" reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, ) return io.NodeOutput(c)
NODE_CLASS_MAPPINGS = { append = execute # TODO: remove
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance, class FluxExtension(ComfyExtension):
"FluxKontextImageScale": FluxKontextImageScale, @override
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod, async def get_node_list(self) -> list[type[io.ComfyNode]]:
} return [
CLIPTextEncodeFlux,
FluxGuidance,
FluxDisableGuidance,
FluxKontextImageScale,
FluxKontextMultiReferenceLatentMethod,
]
async def comfy_entrypoint() -> FluxExtension:
return FluxExtension()

View File

@ -2,6 +2,8 @@ import comfy.utils
import comfy_extras.nodes_post_processing import comfy_extras.nodes_post_processing
import torch import torch
import nodes import nodes
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def reshape_latent_to(target_shape, latent, repeat_batch=True): def reshape_latent_to(target_shape, latent, repeat_batch=True):
@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
return latent return latent
class LatentAdd: class LatentAdd(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return io.Schema(
node_id="LatentAdd",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@ -31,19 +39,25 @@ class LatentAdd:
s2 = reshape_latent_to(s1.shape, s2) s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 + s2 samples_out["samples"] = s1 + s2
return (samples_out,) return io.NodeOutput(samples_out)
class LatentSubtract: class LatentSubtract(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return io.Schema(
node_id="LatentSubtract",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@ -51,41 +65,49 @@ class LatentSubtract:
s2 = reshape_latent_to(s1.shape, s2) s2 = reshape_latent_to(s1.shape, s2)
samples_out["samples"] = s1 - s2 samples_out["samples"] = s1 - s2
return (samples_out,) return io.NodeOutput(samples_out)
class LatentMultiply: class LatentMultiply(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples": ("LATENT",), return io.Schema(
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), node_id="LatentMultiply",
}} category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, multiplier) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples, multiplier):
samples_out = samples.copy() samples_out = samples.copy()
s1 = samples["samples"] s1 = samples["samples"]
samples_out["samples"] = s1 * multiplier samples_out["samples"] = s1 * multiplier
return (samples_out,) return io.NodeOutput(samples_out)
class LatentInterpolate: class LatentInterpolate(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), return io.Schema(
"samples2": ("LATENT",), node_id="LatentInterpolate",
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), category="latent/advanced",
}} inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@ -104,19 +126,26 @@ class LatentInterpolate:
st = torch.nan_to_num(t / mt) st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,) return io.NodeOutput(samples_out)
class LatentConcat: class LatentConcat(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}} return io.Schema(
node_id="LatentConcat",
category="latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, dim):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
@ -136,22 +165,27 @@ class LatentConcat:
dim = -3 dim = -3
samples_out["samples"] = torch.cat(c, dim=dim) samples_out["samples"] = torch.cat(c, dim=dim)
return (samples_out,) return io.NodeOutput(samples_out)
class LatentCut: class LatentCut(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"samples": ("LATENT",), return io.Schema(
"dim": (["x", "y", "t"], ), node_id="LatentCut",
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}), category="latent/advanced",
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}} inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["x", "y", "t"]),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples, dim, index, amount):
samples_out = samples.copy() samples_out = samples.copy()
s1 = samples["samples"] s1 = samples["samples"]
@ -171,19 +205,25 @@ class LatentCut:
amount = min(-index, amount) amount = min(-index, amount)
samples_out["samples"] = torch.narrow(s1, dim, index, amount) samples_out["samples"] = torch.narrow(s1, dim, index, amount)
return (samples_out,) return io.NodeOutput(samples_out)
class LatentBatch: class LatentBatch(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return io.Schema(
node_id="LatentBatch",
category="latent/batch",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "batch" def execute(cls, samples1, samples2) -> io.NodeOutput:
CATEGORY = "latent/batch"
def batch(self, samples1, samples2):
samples_out = samples1.copy() samples_out = samples1.copy()
s1 = samples1["samples"] s1 = samples1["samples"]
s2 = samples2["samples"] s2 = samples2["samples"]
@ -192,20 +232,25 @@ class LatentBatch:
s = torch.cat((s1, s2), dim=0) s = torch.cat((s1, s2), dim=0)
samples_out["samples"] = s samples_out["samples"] = s
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,) return io.NodeOutput(samples_out)
class LatentBatchSeedBehavior: class LatentBatchSeedBehavior(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples": ("LATENT",), return io.Schema(
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} node_id="LatentBatchSeedBehavior",
category="latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, seed_behavior) -> io.NodeOutput:
CATEGORY = "latent/advanced"
def op(self, samples, seed_behavior):
samples_out = samples.copy() samples_out = samples.copy()
latent = samples["samples"] latent = samples["samples"]
if seed_behavior == "random": if seed_behavior == "random":
@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
batch_number = samples_out.get("batch_index", [0])[0] batch_number = samples_out.get("batch_index", [0])[0]
samples_out["batch_index"] = [batch_number] * latent.shape[0] samples_out["batch_index"] = [batch_number] * latent.shape[0]
return (samples_out,) return io.NodeOutput(samples_out)
class LatentApplyOperation: class LatentApplyOperation(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples": ("LATENT",), return io.Schema(
"operation": ("LATENT_OPERATION",), node_id="LatentApplyOperation",
}} category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",) @classmethod
FUNCTION = "op" def execute(cls, samples, operation) -> io.NodeOutput:
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, samples, operation):
samples_out = samples.copy() samples_out = samples.copy()
s1 = samples["samples"] s1 = samples["samples"]
samples_out["samples"] = operation(latent=s1) samples_out["samples"] = operation(latent=s1)
return (samples_out,) return io.NodeOutput(samples_out)
class LatentApplyOperationCFG: class LatentApplyOperationCFG(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"operation": ("LATENT_OPERATION",), node_id="LatentApplyOperationCFG",
}} category="latent/advanced/operations",
RETURN_TYPES = ("MODEL",) is_experimental=True,
FUNCTION = "patch" inputs=[
io.Model.Input("model"),
io.LatentOperation.Input("operation"),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "latent/advanced/operations" @classmethod
EXPERIMENTAL = True def execute(cls, model, operation) -> io.NodeOutput:
def patch(self, model, operation):
m = model.clone() m = model.clone()
def pre_cfg_function(args): def pre_cfg_function(args):
@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
return conds_out return conds_out
m.set_model_sampler_pre_cfg_function(pre_cfg_function) m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m, ) return io.NodeOutput(m)
class LatentOperationTonemapReinhard: class LatentOperationTonemapReinhard(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), return io.Schema(
}} node_id="LatentOperationTonemapReinhard",
category="latent/advanced/operations",
is_experimental=True,
inputs=[
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[
io.LatentOperation.Output(),
],
)
RETURN_TYPES = ("LATENT_OPERATION",) @classmethod
FUNCTION = "op" def execute(cls, multiplier) -> io.NodeOutput:
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, multiplier):
def tonemap_reinhard(latent, **kwargs): def tonemap_reinhard(latent, **kwargs):
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
normalized_latent = latent / latent_vector_magnitude normalized_latent = latent / latent_vector_magnitude
@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
new_magnitude *= top new_magnitude *= top
return normalized_latent * new_magnitude return normalized_latent * new_magnitude
return (tonemap_reinhard,) return io.NodeOutput(tonemap_reinhard)
class LatentOperationSharpen: class LatentOperationSharpen(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"sharpen_radius": ("INT", { node_id="LatentOperationSharpen",
"default": 9, category="latent/advanced/operations",
"min": 1, is_experimental=True,
"max": 31, inputs=[
"step": 1 io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
}), io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
"sigma": ("FLOAT", { io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
"default": 1.0, ],
"min": 0.1, outputs=[
"max": 10.0, io.LatentOperation.Output(),
"step": 0.1 ],
}), )
"alpha": ("FLOAT", {
"default": 0.1,
"min": 0.0,
"max": 5.0,
"step": 0.01
}),
}}
RETURN_TYPES = ("LATENT_OPERATION",) @classmethod
FUNCTION = "op" def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
CATEGORY = "latent/advanced/operations"
EXPERIMENTAL = True
def op(self, sharpen_radius, sigma, alpha):
def sharpen(latent, **kwargs): def sharpen(latent, **kwargs):
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None] luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
normalized_latent = latent / luminance normalized_latent = latent / luminance
@ -340,19 +386,27 @@ class LatentOperationSharpen:
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
return luminance * sharpened return luminance * sharpened
return (sharpen,) return io.NodeOutput(sharpen)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, class LatentExtension(ComfyExtension):
"LatentSubtract": LatentSubtract, @override
"LatentMultiply": LatentMultiply, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"LatentInterpolate": LatentInterpolate, return [
"LatentConcat": LatentConcat, LatentAdd,
"LatentCut": LatentCut, LatentSubtract,
"LatentBatch": LatentBatch, LatentMultiply,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior, LatentInterpolate,
"LatentApplyOperation": LatentApplyOperation, LatentConcat,
"LatentApplyOperationCFG": LatentApplyOperationCFG, LatentCut,
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard, LatentBatch,
"LatentOperationSharpen": LatentOperationSharpen, LatentBatchSeedBehavior,
} LatentApplyOperation,
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
]
async def comfy_entrypoint() -> LatentExtension:
return LatentExtension()

View File

@ -5,6 +5,8 @@ import folder_paths
import os import os
import logging import logging
from enum import Enum from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 0.99
@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu() output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
return output_sd return output_sd
class LoraSave: class LoraSave(io.ComfyNode):
def __init__(self): @classmethod
self.output_dir = folder_paths.get_output_directory() def define_schema(cls):
return io.Schema(
node_id="LoraSave",
display_name="Extract and Save Lora",
category="_for_testing",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
io.Boolean.Input("bias_diff", default=True),
io.Model.Input(
"model_diff",
tooltip="The ModelSubtract output to be converted to a lora.",
optional=True,
),
io.Clip.Input(
"text_encoder_diff",
tooltip="The CLIPSubtract output to be converted to a lora.",
optional=True,
),
],
is_experimental=True,
is_output_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
"lora_type": (tuple(LORA_TYPES.keys()),),
"bias_diff": ("BOOLEAN", {"default": True}),
},
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
if model_diff is None and text_encoder_diff is None: if model_diff is None and text_encoder_diff is None:
return {} return io.NodeOutput()
lora_type = LORA_TYPES.get(lora_type) lora_type = LORA_TYPES.get(lora_type)
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
output_sd = {} output_sd = {}
if model_diff is not None: if model_diff is not None:
@ -108,12 +118,16 @@ class LoraSave:
output_checkpoint = os.path.join(full_output_folder, output_checkpoint) output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None) comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
return {} return io.NodeOutput()
NODE_CLASS_MAPPINGS = {
"LoraSave": LoraSave
}
NODE_DISPLAY_NAME_MAPPINGS = { class LoraSaveExtension(ComfyExtension):
"LoraSave": "Extract and Save Lora" @override
} async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoraSave,
]
async def comfy_entrypoint() -> LoraSaveExtension:
return LoraSaveExtension()

View File

@ -1,24 +1,33 @@
from typing_extensions import override
import comfy.utils import comfy.utils
from comfy_api.latest import ComfyExtension, io
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] class PatchModelAddDownscale(io.ComfyNode):
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return io.Schema(
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), node_id="PatchModelAddDownscale",
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), category="model_patches/unet",
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), inputs=[
"downscale_after_skip": ("BOOLEAN", {"default": True}), io.Model.Input("model"),
"downscale_method": (s.upscale_methods,), io.Int.Input("block_number", default=3, min=1, max=32, step=1),
"upscale_method": (s.upscale_methods,), io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
}} io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
RETURN_TYPES = ("MODEL",) io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
FUNCTION = "patch" io.Boolean.Input("downscale_after_skip", default=True),
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
],
outputs=[
io.Model.Output(),
],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
model_sampling = model.get_model_object("model_sampling") model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent) sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent) sigma_end = model_sampling.percent_to_sigma(end_percent)
@ -41,13 +50,21 @@ class PatchModelAddDownscale:
else: else:
m.set_model_input_block_patch(input_block_patch) m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch) m.set_model_output_block_patch(output_block_patch)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling # Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", "PatchModelAddDownscale": "",
} }
class ModelDownscaleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PatchModelAddDownscale,
]
async def comfy_entrypoint() -> ModelDownscaleExtension:
return ModelDownscaleExtension()

View File

@ -3,64 +3,83 @@ import comfy.sd
import comfy.model_management import comfy.model_management
import nodes import nodes
import torch import torch
import comfy_extras.nodes_slg from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
class TripleCLIPLoader: class TripleCLIPLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), ) return io.Schema(
}} node_id="TripleCLIPLoader",
RETURN_TYPES = ("CLIP",) category="advanced/loaders",
FUNCTION = "load_clip" description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
inputs=[
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
],
outputs=[
io.Clip.Output(),
],
)
CATEGORY = "advanced/loaders" @classmethod
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
def load_clip(self, clip_name1, clip_name2, clip_name3):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,) return io.NodeOutput(clip)
load_clip = execute # TODO: remove
class EmptySD3LatentImage: class EmptySD3LatentImage(io.ComfyNode):
def __init__(self): @classmethod
self.device = comfy.model_management.intermediate_device() def define_schema(cls):
return io.Schema(
node_id="EmptySD3LatentImage",
category="latent/sd3",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), return io.NodeOutput({"samples":latent})
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/sd3" generate = execute # TODO: remove
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )
class CLIPTextEncodeSD3: class CLIPTextEncodeSD3(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return io.Schema(
"clip": ("CLIP", ), node_id="CLIPTextEncodeSD3",
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), category="advanced/conditioning",
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), inputs=[
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), io.Clip.Input("clip"),
"empty_padding": (["none", "empty_prompt"], ) io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
}} io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
RETURN_TYPES = ("CONDITIONING",) io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
FUNCTION = "encode" io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
],
outputs=[
io.Conditioning.Output(),
],
)
CATEGORY = "advanced/conditioning" @classmethod
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
no_padding = empty_padding == "none" no_padding = empty_padding == "none"
tokens = clip.tokenize(clip_g) tokens = clip.tokenize(clip_g)
@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
tokens["l"] += empty["l"] tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]): while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"] tokens["g"] += empty["g"]
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
encode = execute # TODO: remove
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): class ControlNetApplySD3(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="ControlNetApplySD3",
"control_net": ("CONTROL_NET", ), display_name="Apply Controlnet with VAE",
"vae": ("VAE", ), category="conditioning/controlnet",
"image": ("IMAGE", ), inputs=[
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), io.Conditioning.Input("positive"),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), io.Conditioning.Input("negative"),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) io.ControlNet.Input("control_net"),
}} io.Vae.Input("vae"),
CATEGORY = "conditioning/controlnet" io.Image.Input("image"),
DEPRECATED = True io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
is_deprecated=True,
)
@classmethod
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
if strength == 0:
return io.NodeOutput(positive, negative)
control_hint = image.movedim(-1, 1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
vae=vae, extra_concat=[])
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1])
apply_controlnet = execute # TODO: remove
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT): class SkipLayerGuidanceSD3(io.ComfyNode):
''' '''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Experimental implementation by Dango233@StabilityAI. Experimental implementation by Dango233@StabilityAI.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceSD3",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), category="advanced/guidance",
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) inputs=[
}} io.Model.Input("model"),
RETURN_TYPES = ("MODEL",) io.String.Input("layers", default="7, 8, 9", multiline=False),
FUNCTION = "skip_guidance_sd3" io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
is_experimental=True,
)
CATEGORY = "advanced/guidance" @classmethod
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent): skip_guidance_sd3 = execute # TODO: remove
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
NODE_CLASS_MAPPINGS = { class SD3Extension(ComfyExtension):
"TripleCLIPLoader": TripleCLIPLoader, @override
"EmptySD3LatentImage": EmptySD3LatentImage, async def get_node_list(self) -> list[type[io.ComfyNode]]:
"CLIPTextEncodeSD3": CLIPTextEncodeSD3, return [
"ControlNetApplySD3": ControlNetApplySD3, TripleCLIPLoader,
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, EmptySD3LatentImage,
} CLIPTextEncodeSD3,
ControlNetApplySD3,
SkipLayerGuidanceSD3,
]
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling async def comfy_entrypoint() -> SD3Extension:
"ControlNetApplySD3": "Apply Controlnet with VAE", return SD3Extension()
}

View File

@ -1,33 +1,40 @@
import comfy.model_patcher import comfy.model_patcher
import comfy.samplers import comfy.samplers
import re import re
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class SkipLayerGuidanceDiT: class SkipLayerGuidanceDiT(io.ComfyNode):
''' '''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI. Original experimental implementation for SD3 by Dango233@StabilityAI.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceDiT",
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), category="advanced/guidance",
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), is_experimental=True,
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), inputs=[
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), io.Model.Input("model"),
}} io.String.Input("double_layers", default="7, 8, 9"),
RETURN_TYPES = ("MODEL",) io.String.Input("single_layers", default="7, 8, 9"),
FUNCTION = "skip_guidance" io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
EXPERIMENTAL = True io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model." @classmethod
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
# check if layer is comma separated integers # check if layer is comma separated integers
def skip(args, extra_args): def skip(args, extra_args):
return args return args
@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
single_layers = [int(i) for i in single_layers] single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0: if len(double_layers) == 0 and len(single_layers) == 0:
return (model, ) return io.NodeOutput(model)
def post_cfg_function(args): def post_cfg_function(args):
model = args["model"] model = args["model"]
@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
m = model.clone() m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function) m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m, ) return io.NodeOutput(m)
class SkipLayerGuidanceDiTSimple: skip_guidance = execute # TODO: remove
class SkipLayerGuidanceDiTSimple(io.ComfyNode):
''' '''
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass. Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
''' '''
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"model": ("MODEL", ), return io.Schema(
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), node_id="SkipLayerGuidanceDiTSimple",
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), category="advanced/guidance",
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), is_experimental=True,
}} inputs=[
RETURN_TYPES = ("MODEL",) io.Model.Input("model"),
FUNCTION = "skip_guidance" io.String.Input("double_layers", default="7, 8, 9"),
EXPERIMENTAL = True io.String.Input("single_layers", default="7, 8, 9"),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Model.Output(),
],
)
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass." @classmethod
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
CATEGORY = "advanced/guidance"
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
def skip(args, extra_args): def skip(args, extra_args):
return args return args
@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
single_layers = [int(i) for i in single_layers] single_layers = [int(i) for i in single_layers]
if len(double_layers) == 0 and len(single_layers) == 0: if len(double_layers) == 0 and len(single_layers) == 0:
return (model, ) return io.NodeOutput(model)
def calc_cond_batch_function(args): def calc_cond_batch_function(args):
x = args["input"] x = args["input"]
@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
m = model.clone() m = model.clone()
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function) m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
return (m, ) return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = { skip_guidance = execute # TODO: remove
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
} class SkipLayerGuidanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SkipLayerGuidanceDiT,
SkipLayerGuidanceDiTSimple,
]
async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
return SkipLayerGuidanceExtension()

View File

@ -4,6 +4,8 @@ from comfy import model_management
import torch import torch
import comfy.utils import comfy.utils
import folder_paths import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
try: try:
from spandrel_extra_arches import EXTRA_REGISTRY from spandrel_extra_arches import EXTRA_REGISTRY
@ -13,17 +15,23 @@ try:
except: except:
pass pass
class UpscaleModelLoader: class UpscaleModelLoader(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), return io.Schema(
}} node_id="UpscaleModelLoader",
RETURN_TYPES = ("UPSCALE_MODEL",) display_name="Load Upscale Model",
FUNCTION = "load_model" category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
],
outputs=[
io.UpscaleModel.Output(),
],
)
CATEGORY = "loaders" @classmethod
def execute(cls, model_name) -> io.NodeOutput:
def load_model(self, model_name):
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name) model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True) sd = comfy.utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
@ -33,21 +41,29 @@ class UpscaleModelLoader:
if not isinstance(out, ImageModelDescriptor): if not isinstance(out, ImageModelDescriptor):
raise Exception("Upscale model must be a single-image model.") raise Exception("Upscale model must be a single-image model.")
return (out, ) return io.NodeOutput(out)
load_model = execute # TODO: remove
class ImageUpscaleWithModel: class ImageUpscaleWithModel(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "upscale_model": ("UPSCALE_MODEL",), return io.Schema(
"image": ("IMAGE",), node_id="ImageUpscaleWithModel",
}} display_name="Upscale Image (using Model)",
RETURN_TYPES = ("IMAGE",) category="image/upscaling",
FUNCTION = "upscale" inputs=[
io.UpscaleModel.Input("upscale_model"),
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
],
)
CATEGORY = "image/upscaling" @classmethod
def execute(cls, upscale_model, image) -> io.NodeOutput:
def upscale(self, upscale_model, image):
device = model_management.get_torch_device() device = model_management.get_torch_device()
memory_required = model_management.module_size(upscale_model.model) memory_required = model_management.module_size(upscale_model.model)
@ -75,9 +91,19 @@ class ImageUpscaleWithModel:
upscale_model.to("cpu") upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return (s,) return io.NodeOutput(s)
NODE_CLASS_MAPPINGS = { upscale = execute # TODO: remove
"UpscaleModelLoader": UpscaleModelLoader,
"ImageUpscaleWithModel": ImageUpscaleWithModel
} class UpscaleModelExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
UpscaleModelLoader,
ImageUpscaleWithModel,
]
async def comfy_entrypoint() -> UpscaleModelExtension:
return UpscaleModelExtension()

View File

@ -2027,7 +2027,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DiffControlNetLoader": "Load ControlNet Model (diff)", "DiffControlNetLoader": "Load ControlNet Model (diff)",
"StyleModelLoader": "Load Style Model", "StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision", "CLIPVisionLoader": "Load CLIP Vision",
"UpscaleModelLoader": "Load Upscale Model",
"UNETLoader": "Load Diffusion Model", "UNETLoader": "Load Diffusion Model",
# Conditioning # Conditioning
"CLIPVisionEncode": "CLIP Vision Encode", "CLIPVisionEncode": "CLIP Vision Encode",
@ -2065,7 +2064,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadImageOutput": "Load Image (from Outputs)", "LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image", "ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By", "ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
"ImageInvert": "Invert Image", "ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting", "ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images", "ImageBatch": "Batch Images",

View File

@ -61,7 +61,6 @@ messages_control.disable = [
# next warnings should be fixed in future # next warnings should be fixed in future
"bad-classmethod-argument", # Class method should have 'cls' as first argument "bad-classmethod-argument", # Class method should have 'cls' as first argument
"wrong-import-order", # Standard imports should be placed before third party imports "wrong-import-order", # Standard imports should be placed before third party imports
"logging-fstring-interpolation", # Use lazy % formatting in logging functions
"ungrouped-imports", "ungrouped-imports",
"unnecessary-pass", "unnecessary-pass",
"unnecessary-lambda-assignment", "unnecessary-lambda-assignment",