Merge branch 'master' into dr-support-pip-cm

Dr.Lt.Data 2025-10-10 08:15:03 +09:00
commit 4e7f2eeae2
15 changed files with 722 additions and 542 deletions

View File

@@ -123,16 +123,30 @@ def move_weight_functions(m, device):
     return memory
 
 class LowVramPatch:
-    def __init__(self, key, patches):
+    def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
+        self.convert_func = convert_func
+        self.set_func = set_func
+
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
+        if self.convert_func is not None:
+            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
-            return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
+            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
+            if self.set_func is None:
+                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
+            else:
+                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
 
-        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
+        if self.set_func is not None:
+            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
+        else:
+            return out
 
 def get_key_weight(model, key):
     set_func = None
@@ -657,13 +671,15 @@ class ModelPatcher:
                     if force_patch_weights:
                         self.patch_weight_to_device(weight_key)
                     else:
-                        m.weight_function = [LowVramPatch(weight_key, self.patches)]
+                        _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                        m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
                         patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
-                        m.bias_function = [LowVramPatch(bias_key, self.patches)]
+                        _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                        m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
                         patch_counter += 1
 
                 cast_weight = True
@@ -825,10 +841,12 @@ class ModelPatcher:
                     module_mem += move_weight_functions(m, device_to)
                 if lowvram_possible:
                     if weight_key in self.patches:
-                        m.weight_function.append(LowVramPatch(weight_key, self.patches))
+                        _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                        m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
                         patch_counter += 1
                     if bias_key in self.patches:
-                        m.bias_function.append(LowVramPatch(bias_key, self.patches))
+                        _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                        m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
                         patch_counter += 1
                     cast_weight = True
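
The thread through all three hunks: get_key_weight returns a (weight, set_func, convert_func) triple, and LowVramPatch now uses convert_func to de-quantize the stored weight before the LoRA math and set_func to re-quantize the result. A minimal sketch of that round-trip with stand-in hooks, not part of the diff (the real pair comes from get_key_weight and, for scaled-fp8 layers, scaled_fp8_ops):

import torch

SCALE = 2.0  # stand-in for a layer's scale_weight

def convert_func(weight, inplace=True):
    # de-scale into plain float32 so patch math sees real values
    return weight * SCALE

def set_func(weight, seed=None, return_weight=False):
    # re-scale and cast back to the storage dtype (fp8 in the real code)
    return (weight / SCALE).to(torch.float16)

stored = torch.randn(4, 4, dtype=torch.float16)             # quantized storage
w = convert_func(stored.to(dtype=torch.float32, copy=True), inplace=True)
w += 0.01                                                   # stand-in for calculate_weight()
patched = set_func(w, seed=0, return_weight=True)           # back to storage form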

View File

@@ -416,8 +416,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
             else:
                 return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
 
-        def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
+        def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
+            if return_weight:
+                return weight
             if inplace_update:
                 self.weight.data.copy_(weight)
             else:
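
set_weight gains a return_weight escape hatch: the stochastically rounded tensor is handed back to the caller (the LowVramPatch above) instead of being written into self.weight. Stochastic rounding is what keeps repeated low-precision casts unbiased; an integer-grid sketch of the idea for illustration only (comfy.float.stochastic_rounding applies the same idea on floating-point grids):

import torch

def stochastic_round(x):
    # floor(x) + Bernoulli(frac(x)): unbiased in expectation, unlike round()
    floor = torch.floor(x)
    return floor + (torch.rand_like(x) < (x - floor)).to(x.dtype)

x = torch.full((100_000,), 0.25)
print(stochastic_round(x).mean())   # ~0.25; plain rounding would give 0.0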

View File

@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
 from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
 from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
 from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
-from comfy_api.latest._io import _IO as io  #noqa: F401
-from comfy_api.latest._ui import _UI as ui  #noqa: F401
+from . import _io as io
+from . import _ui as ui
 # from comfy_api.latest._resources import _RESOURCES as resources  #noqa: F401
 from comfy_execution.utils import get_executing_context
 from comfy_execution.progress import get_progress_state, PreviewImageTuple
@@ -114,6 +114,8 @@ if TYPE_CHECKING:
     ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
 ComfyAPISync = create_sync_class(ComfyAPI_latest)
 
+comfy_io = io  # create the new alias for io
+
 __all__ = [
     "ComfyAPI",
     "ComfyAPISync",
@@ -121,4 +123,7 @@ __all__ = [
     "InputImpl",
     "Types",
     "ComfyExtension",
+    "io",
+    "comfy_io",
+    "ui",
 ]
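
After this change io and ui are real submodules rather than facade classes, and comfy_io is simply a second name bound to the same module object, so both import spellings resolve identically. A usage sketch, assuming this branch is installed:

from comfy_api.latest import io, comfy_io

assert comfy_io is io   # one module object, two exported names
prompt = comfy_io.String.Input("prompt_text", multiline=True)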

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations
 from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import Optional, Union, IO
 import io
 import av
 from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
@@ -23,7 +23,7 @@ class VideoInput(ABC):
     @abstractmethod
     def save_to(
         self,
-        path: str,
+        path: Union[str, IO[bytes]],
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
         metadata: Optional[dict] = None
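
The widened annotation documents what callers already do: save_to can write into any binary stream, not just a filesystem path, which is what the Pika nodes below rely on when serializing straight into an upload buffer. A fragment of the intended use (here video stands for any VideoInput implementation; the enums are the ones imported in this file):

from io import BytesIO

buffer = BytesIO()
video.save_to(buffer, format=VideoContainer.MP4, codec=VideoCodec.H264)  # video: a VideoInput
buffer.seek(0)  # rewind before handing the bytes to a multipart upload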

View File

@@ -1582,78 +1582,78 @@ class _UIOutput(ABC):
         ...
 
-class _IO:
-    FolderType = FolderType
-    UploadType = UploadType
-    RemoteOptions = RemoteOptions
-    NumberDisplay = NumberDisplay
-    comfytype = staticmethod(comfytype)
-    Custom = staticmethod(Custom)
-    Input = Input
-    WidgetInput = WidgetInput
-    Output = Output
-    ComfyTypeI = ComfyTypeI
-    ComfyTypeIO = ComfyTypeIO
-    #---------------------------------
+__all__ = [
+    "FolderType",
+    "UploadType",
+    "RemoteOptions",
+    "NumberDisplay",
+    "comfytype",
+    "Custom",
+    "Input",
+    "WidgetInput",
+    "Output",
+    "ComfyTypeI",
+    "ComfyTypeIO",
     # Supported Types
-    Boolean = Boolean
-    Int = Int
-    Float = Float
-    String = String
-    Combo = Combo
-    MultiCombo = MultiCombo
-    Image = Image
-    WanCameraEmbedding = WanCameraEmbedding
-    Webcam = Webcam
-    Mask = Mask
-    Latent = Latent
-    Conditioning = Conditioning
-    Sampler = Sampler
-    Sigmas = Sigmas
-    Noise = Noise
-    Guider = Guider
-    Clip = Clip
-    ControlNet = ControlNet
-    Vae = Vae
-    Model = Model
-    ClipVision = ClipVision
-    ClipVisionOutput = ClipVisionOutput
-    AudioEncoder = AudioEncoder
-    AudioEncoderOutput = AudioEncoderOutput
-    StyleModel = StyleModel
-    Gligen = Gligen
-    UpscaleModel = UpscaleModel
-    Audio = Audio
-    Video = Video
-    SVG = SVG
-    LoraModel = LoraModel
-    LossMap = LossMap
-    Voxel = Voxel
-    Mesh = Mesh
-    Hooks = Hooks
-    HookKeyframes = HookKeyframes
-    TimestepsRange = TimestepsRange
-    LatentOperation = LatentOperation
-    FlowControl = FlowControl
-    Accumulation = Accumulation
-    Load3DCamera = Load3DCamera
-    Load3D = Load3D
-    Load3DAnimation = Load3DAnimation
-    Photomaker = Photomaker
-    Point = Point
-    FaceAnalysis = FaceAnalysis
-    BBOX = BBOX
-    SEGS = SEGS
-    AnyType = AnyType
-    MultiType = MultiType
-    #---------------------------------
-    HiddenHolder = HiddenHolder
-    Hidden = Hidden
-    NodeInfoV1 = NodeInfoV1
-    NodeInfoV3 = NodeInfoV3
-    Schema = Schema
-    ComfyNode = ComfyNode
-    NodeOutput = NodeOutput
-    add_to_dict_v1 = staticmethod(add_to_dict_v1)
-    add_to_dict_v3 = staticmethod(add_to_dict_v3)
+    "Boolean",
+    "Int",
+    "Float",
+    "String",
+    "Combo",
+    "MultiCombo",
+    "Image",
+    "WanCameraEmbedding",
+    "Webcam",
+    "Mask",
+    "Latent",
+    "Conditioning",
+    "Sampler",
+    "Sigmas",
+    "Noise",
+    "Guider",
+    "Clip",
+    "ControlNet",
+    "Vae",
+    "Model",
+    "ClipVision",
+    "ClipVisionOutput",
+    "AudioEncoder",
+    "AudioEncoderOutput",
+    "StyleModel",
+    "Gligen",
+    "UpscaleModel",
+    "Audio",
+    "Video",
+    "SVG",
+    "LoraModel",
+    "LossMap",
+    "Voxel",
+    "Mesh",
+    "Hooks",
+    "HookKeyframes",
+    "TimestepsRange",
+    "LatentOperation",
+    "FlowControl",
+    "Accumulation",
+    "Load3DCamera",
+    "Load3D",
+    "Load3DAnimation",
+    "Photomaker",
+    "Point",
+    "FaceAnalysis",
+    "BBOX",
+    "SEGS",
+    "AnyType",
+    "MultiType",
+    # Other classes
+    "HiddenHolder",
+    "Hidden",
+    "NodeInfoV1",
+    "NodeInfoV3",
+    "Schema",
+    "ComfyNode",
+    "NodeOutput",
+    "add_to_dict_v1",
+    "add_to_dict_v3",
+]
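
Callers keep the same attribute spellings; what changes is that they now resolve as ordinary module attributes and the public surface is an explicit __all__ (sketch):

from comfy_api.latest import io

print(io.Boolean)                 # same access path as the old _IO class attribute
print("Boolean" in io.__all__)    # True: exports are now declared, not class members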

View File

@@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
         return {"text": (self.value,)}
 
-class _UI:
-    SavedResult = SavedResult
-    SavedImages = SavedImages
-    SavedAudios = SavedAudios
-    ImageSaveHelper = ImageSaveHelper
-    AudioSaveHelper = AudioSaveHelper
-    PreviewImage = PreviewImage
-    PreviewMask = PreviewMask
-    PreviewAudio = PreviewAudio
-    PreviewVideo = PreviewVideo
-    PreviewUI3D = PreviewUI3D
-    PreviewText = PreviewText
+__all__ = [
+    "SavedResult",
+    "SavedImages",
+    "SavedAudios",
+    "ImageSaveHelper",
+    "AudioSaveHelper",
+    "PreviewImage",
+    "PreviewMask",
+    "PreviewAudio",
+    "PreviewVideo",
+    "PreviewUI3D",
+    "PreviewText",
+]

View File

@@ -269,7 +269,7 @@ def tensor_to_bytesio(
         mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
 
     Returns:
-        Named BytesIO object containing the image data.
+        Named BytesIO object containing the image data, with pointer set to the start of buffer.
     """
     if not mime_type:
         mime_type = "image/png"

View File

@@ -98,7 +98,7 @@ import io
 import os
 import socket
 from aiohttp.client_exceptions import ClientError, ClientResponseError
-from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
+from typing import Type, Optional, Any, TypeVar, Generic, Callable
 from enum import Enum
 import json
 from urllib.parse import urljoin, urlparse
@@ -175,7 +175,7 @@ class ApiClient:
         max_retries: int = 3,
         retry_delay: float = 1.0,
         retry_backoff_factor: float = 2.0,
-        retry_status_codes: Optional[Tuple[int, ...]] = None,
+        retry_status_codes: Optional[tuple[int, ...]] = None,
         session: Optional[aiohttp.ClientSession] = None,
     ):
         self.base_url = base_url
@@ -199,9 +199,9 @@ class ApiClient:
     @staticmethod
     def _create_json_payload_args(
-        data: Optional[Dict[str, Any]] = None,
-        headers: Optional[Dict[str, str]] = None,
-    ) -> Dict[str, Any]:
+        data: Optional[dict[str, Any]] = None,
+        headers: Optional[dict[str, str]] = None,
+    ) -> dict[str, Any]:
         return {
             "json": data,
             "headers": headers,
@@ -209,11 +209,11 @@ class ApiClient:
     def _create_form_data_args(
         self,
-        data: Dict[str, Any] | None,
-        files: Dict[str, Any] | None,
-        headers: Optional[Dict[str, str]] = None,
+        data: dict[str, Any] | None,
+        files: dict[str, Any] | None,
+        headers: Optional[dict[str, str]] = None,
         multipart_parser: Callable | None = None,
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         if headers and "Content-Type" in headers:
             del headers["Content-Type"]
@@ -254,9 +254,9 @@ class ApiClient:
     @staticmethod
     def _create_urlencoded_form_data_args(
-        data: Dict[str, Any],
-        headers: Optional[Dict[str, str]] = None,
-    ) -> Dict[str, Any]:
+        data: dict[str, Any],
+        headers: Optional[dict[str, str]] = None,
+    ) -> dict[str, Any]:
         headers = headers or {}
         headers["Content-Type"] = "application/x-www-form-urlencoded"
         return {
@@ -264,7 +264,7 @@ class ApiClient:
             "headers": headers,
         }
 
-    def get_headers(self) -> Dict[str, str]:
+    def get_headers(self) -> dict[str, str]:
         """Get headers for API requests, including authentication if available"""
         headers = {"Content-Type": "application/json", "Accept": "application/json"}
@@ -275,7 +275,7 @@ class ApiClient:
         return headers
 
-    async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
+    async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
         """
         Check connectivity to determine if network issues are local or server-related.
@@ -316,14 +316,14 @@ class ApiClient:
         self,
         method: str,
         path: str,
-        params: Optional[Dict[str, Any]] = None,
-        data: Optional[Dict[str, Any]] = None,
-        files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
-        headers: Optional[Dict[str, str]] = None,
+        params: Optional[dict[str, Any]] = None,
+        data: Optional[dict[str, Any]] = None,
+        files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
+        headers: Optional[dict[str, str]] = None,
         content_type: str = "application/json",
         multipart_parser: Callable | None = None,
         retry_count: int = 0,  # Used internally for tracking retries
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         """
         Make an HTTP request to the API with automatic retries for transient errors.
@@ -485,7 +485,7 @@ class ApiClient:
             retry_delay: Initial delay between retries in seconds
             retry_backoff_factor: Multiplier for the delay after each retry
         """
-        headers: Dict[str, str] = {}
+        headers: dict[str, str] = {}
         skip_auto_headers: set[str] = set()
         if content_type:
             headers["Content-Type"] = content_type
@@ -558,7 +558,7 @@ class ApiClient:
         *req_meta,
         retry_count: int,
         response_content: dict | str = "",
-    ) -> Dict[str, Any]:
+    ) -> dict[str, Any]:
         status_code = exc.status
         if status_code == 401:
             user_friendly = "Unauthorized: Please login first to use this node."
@@ -659,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
         method: HttpMethod,
         request_model: Type[T],
         response_model: Type[R],
-        query_params: Optional[Dict[str, Any]] = None,
+        query_params: Optional[dict[str, Any]] = None,
     ):
         """Initialize an API endpoint definition.
@@ -684,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
         self,
         endpoint: ApiEndpoint[T, R],
         request: T,
-        files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
+        files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
         api_base: str | None = None,
         auth_token: Optional[str] = None,
         comfy_api_key: Optional[str] = None,
-        auth_kwargs: Optional[Dict[str, str]] = None,
+        auth_kwargs: Optional[dict[str, str]] = None,
         timeout: float = 7200.0,
         verify_ssl: bool = True,
         content_type: str = "application/json",
@@ -729,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
             )
 
             try:
-                request_dict: Optional[Dict[str, Any]]
+                request_dict: Optional[dict[str, Any]]
                 if isinstance(self.request, EmptyRequest):
                     request_dict = None
                 else:
@@ -782,14 +782,14 @@ class PollingOperation(Generic[T, R]):
         poll_endpoint: ApiEndpoint[EmptyRequest, R],
         completed_statuses: list[str],
         failed_statuses: list[str],
-        status_extractor: Callable[[R], str],
-        progress_extractor: Callable[[R], float] | None = None,
-        result_url_extractor: Callable[[R], str] | None = None,
+        status_extractor: Callable[[R], Optional[str]],
+        progress_extractor: Callable[[R], Optional[float]] | None = None,
+        result_url_extractor: Callable[[R], Optional[str]] | None = None,
         request: Optional[T] = None,
         api_base: str | None = None,
         auth_token: Optional[str] = None,
         comfy_api_key: Optional[str] = None,
-        auth_kwargs: Optional[Dict[str, str]] = None,
+        auth_kwargs: Optional[dict[str, str]] = None,
         poll_interval: float = 5.0,
         max_poll_attempts: int = 120,  # Default max polling attempts (10 minutes with 5s interval)
         max_retries: int = 3,  # Max retries per individual API call
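
Every hunk in this file is the same mechanical migration: typing.Dict and typing.Tuple become the builtin generics that PEP 585 made subscriptable in Python 3.9, so the Dict and Tuple imports can be dropped. The two spellings are equivalent at runtime:

from typing import Any, Optional

# before: from typing import Dict, Tuple
# def get_headers(self) -> Dict[str, str]: ...
def get_headers() -> dict[str, str]:      # builtin generic, no typing import
    return {"Content-Type": "application/json", "Accept": "application/json"}

retry_status_codes: Optional[tuple[int, ...]] = (429, 502, 503, 504)
payload: dict[str, Any] = {"json": None, "headers": get_headers()}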

View File

@@ -0,0 +1,100 @@
+from typing import Optional
+from enum import Enum
+
+from pydantic import BaseModel, Field
+
+
+class Pikaffect(str, Enum):
+    Cake_ify = "Cake-ify"
+    Crumble = "Crumble"
+    Crush = "Crush"
+    Decapitate = "Decapitate"
+    Deflate = "Deflate"
+    Dissolve = "Dissolve"
+    Explode = "Explode"
+    Eye_pop = "Eye-pop"
+    Inflate = "Inflate"
+    Levitate = "Levitate"
+    Melt = "Melt"
+    Peel = "Peel"
+    Poke = "Poke"
+    Squish = "Squish"
+    Ta_da = "Ta-da"
+    Tear = "Tear"
+
+
+class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
+    aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
+    duration: Optional[int] = Field(5)
+    ingredientsMode: str = Field(...)
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    resolution: Optional[str] = Field('1080p')
+    seed: Optional[int] = Field(None)
+
+
+class PikaGenerateResponse(BaseModel):
+    video_id: str = Field(...)
+
+
+class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
+    duration: Optional[int] = 5
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    resolution: Optional[str] = '1080p'
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
+    duration: Optional[int] = Field(None, ge=5, le=10)
+    negativePrompt: Optional[str] = Field(None)
+    promptText: str = Field(...)
+    resolution: Optional[str] = '1080p'
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
+    aspectRatio: Optional[float] = Field(
+        1.7777777777777777,
+        description='Aspect ratio (width / height)',
+        ge=0.4,
+        le=2.5,
+    )
+    duration: Optional[int] = 5
+    negativePrompt: Optional[str] = Field(None)
+    promptText: str = Field(...)
+    resolution: Optional[str] = '1080p'
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
+    negativePrompt: Optional[str] = Field(None)
+    pikaffect: Optional[str] = None
+    promptText: Optional[str] = Field(None)
+    seed: Optional[int] = Field(None)
+
+
+class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
+    negativePrompt: Optional[str] = Field(None)
+    promptText: Optional[str] = Field(None)
+    seed: Optional[int] = Field(None)
+    modifyRegionRoi: Optional[str] = Field(None)
+
+
+class PikaStatusEnum(str, Enum):
+    queued = "queued"
+    started = "started"
+    finished = "finished"
+    failed = "failed"
+
+
+class PikaVideoResponse(BaseModel):
+    id: str = Field(...)
+    progress: Optional[int] = Field(None)
+    status: PikaStatusEnum
+    url: Optional[str] = Field(None)
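
These pydantic models validate their Field constraints at construction, so a bad duration fails before any network call is made. A usage sketch, assuming pydantic v2 (where model_dump exists):

from comfy_api_nodes.apis import pika_defs

body = pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
    promptText="a crumbling statue",
    duration=7,     # outside ge=5/le=10 this raises ValidationError
    seed=42,
)
print(body.model_dump(exclude_none=True))
# {'duration': 7, 'promptText': 'a crumbling statue', 'resolution': '1080p', 'seed': 42}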

View File

@@ -8,30 +8,17 @@ from __future__ import annotations
 from io import BytesIO
 import logging
 from typing import Optional, TypeVar
-from enum import Enum
-
-import numpy as np
 import torch
 from typing_extensions import override
 
-from comfy_api.latest import ComfyExtension, io as comfy_io
-from comfy_api.input_impl import VideoFromFile
+from comfy_api.latest import ComfyExtension, comfy_io
 from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
 from comfy_api_nodes.apinode_utils import (
     download_url_to_video_output,
     tensor_to_bytesio,
 )
-from comfy_api_nodes.apis import (
-    PikaBodyGenerate22C2vGenerate22PikascenesPost,
-    PikaBodyGenerate22I2vGenerate22I2vPost,
-    PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
-    PikaBodyGenerate22T2vGenerate22T2vPost,
-    PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
-    PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
-    PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
-    PikaGenerateResponse,
-    PikaVideoResponse,
-)
+from comfy_api_nodes.apis import pika_defs
 from comfy_api_nodes.apis.client import (
     ApiEndpoint,
     EmptyRequest,
@@ -55,116 +42,36 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
 PATH_VIDEO_GET = "/proxy/pika/videos"
 
-class PikaDurationEnum(int, Enum):
-    integer_5 = 5
-    integer_10 = 10
-
-class PikaResolutionEnum(str, Enum):
-    field_1080p = "1080p"
-    field_720p = "720p"
-
-class Pikaffect(str, Enum):
-    Cake_ify = "Cake-ify"
-    Crumble = "Crumble"
-    Crush = "Crush"
-    Decapitate = "Decapitate"
-    Deflate = "Deflate"
-    Dissolve = "Dissolve"
-    Explode = "Explode"
-    Eye_pop = "Eye-pop"
-    Inflate = "Inflate"
-    Levitate = "Levitate"
-    Melt = "Melt"
-    Peel = "Peel"
-    Poke = "Poke"
-    Squish = "Squish"
-    Ta_da = "Ta-da"
-    Tear = "Tear"
-
-class PikaApiError(Exception):
-    """Exception for Pika API errors."""
-    pass
-
-def is_valid_video_response(response: PikaVideoResponse) -> bool:
-    """Check if the video response is valid."""
-    return hasattr(response, "url") and response.url is not None
-
-def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
-    """Check if the initial response is valid."""
-    return hasattr(response, "video_id") and response.video_id is not None
-
-async def poll_for_task_status(
-    task_id: str,
+async def execute_task(
+    initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
     auth_kwargs: Optional[dict[str, str]] = None,
     node_id: Optional[str] = None,
-) -> PikaGenerateResponse:
-    polling_operation = PollingOperation(
+) -> comfy_io.NodeOutput:
+    task_id = (await initial_operation.execute()).video_id
+    final_response: pika_defs.PikaVideoResponse = await PollingOperation(
         poll_endpoint=ApiEndpoint(
             path=f"{PATH_VIDEO_GET}/{task_id}",
             method=HttpMethod.GET,
             request_model=EmptyRequest,
-            response_model=PikaVideoResponse,
+            response_model=pika_defs.PikaVideoResponse,
         ),
-        completed_statuses=[
-            "finished",
-        ],
+        completed_statuses=["finished"],
         failed_statuses=["failed", "cancelled"],
-        status_extractor=lambda response: (
-            response.status.value if response.status else None
-        ),
-        progress_extractor=lambda response: (
-            response.progress if hasattr(response, "progress") else None
-        ),
+        status_extractor=lambda response: (response.status.value if response.status else None),
+        progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
         auth_kwargs=auth_kwargs,
-        result_url_extractor=lambda response: (
-            response.url if hasattr(response, "url") else None
-        ),
+        result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
        node_id=node_id,
-        estimated_duration=60
-    )
-    return await polling_operation.execute()
-
-
-async def execute_task(
-    initial_operation: SynchronousOperation[R, PikaGenerateResponse],
-    auth_kwargs: Optional[dict[str, str]] = None,
-    node_id: Optional[str] = None,
-) -> tuple[VideoFromFile]:
-    """Executes the initial operation then polls for the task status until it is completed.
-
-    Args:
-        initial_operation: The initial operation to execute.
-        auth_kwargs: The authentication token(s) to use for the API call.
-
-    Returns:
-        A tuple containing the video file as a VIDEO output.
-    """
-    initial_response = await initial_operation.execute()
-    if not is_valid_initial_response(initial_response):
-        error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
+        estimated_duration=60,
+        max_poll_attempts=240,
+    ).execute()
+    if not final_response.url:
+        error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
         logging.error(error_msg)
-        raise PikaApiError(error_msg)
-
-    task_id = initial_response.video_id
-    final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
-    if not is_valid_video_response(final_response):
-        error_msg = (
-            f"Pika task {task_id} succeeded but no video data found in response."
-        )
-        logging.error(error_msg)
-        raise PikaApiError(error_msg)
-
-    video_url = str(final_response.url)
+        raise Exception(error_msg)
+    video_url = final_response.url
     logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
-
-    return (await download_url_to_video_output(video_url),)
+    return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
 
 
 def get_base_inputs_types() -> list[comfy_io.Input]:
@@ -173,16 +80,12 @@ def get_base_inputs_types() -> list[comfy_io.Input]:
         comfy_io.String.Input("prompt_text", multiline=True),
         comfy_io.String.Input("negative_prompt", multiline=True),
         comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
-        comfy_io.Combo.Input(
-            "resolution", options=PikaResolutionEnum, default=PikaResolutionEnum.field_1080p
-        ),
-        comfy_io.Combo.Input(
-            "duration", options=PikaDurationEnum, default=PikaDurationEnum.integer_5
-        ),
+        comfy_io.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
+        comfy_io.Combo.Input("duration", options=[5, 10], default=5),
     ]
 
-class PikaImageToVideoV2_2(comfy_io.ComfyNode):
+class PikaImageToVideo(comfy_io.ComfyNode):
     """Pika 2.2 Image to Video Node."""
 
     @classmethod
@@ -215,14 +118,9 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
         resolution: str,
         duration: int,
     ) -> comfy_io.NodeOutput:
-        # Convert image to BytesIO
         image_bytes_io = tensor_to_bytesio(image)
-        image_bytes_io.seek(0)
         pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
-
-        # Prepare non-file data
-        pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
+        pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
             promptText=prompt_text,
             negativePrompt=negative_prompt,
             seed=seed,
@@ -237,8 +135,8 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
             endpoint=ApiEndpoint(
                 path=PATH_IMAGE_TO_VIDEO,
                 method=HttpMethod.POST,
-                request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
-                response_model=PikaGenerateResponse,
+                request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
+                response_model=pika_defs.PikaGenerateResponse,
             ),
             request=pika_request_data,
             files=pika_files,
@@ -248,7 +146,7 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
         return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
 
-class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
+class PikaTextToVideoNode(comfy_io.ComfyNode):
     """Pika Text2Video v2.2 Node."""
 
     @classmethod
@@ -296,10 +194,10 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
             endpoint=ApiEndpoint(
                 path=PATH_TEXT_TO_VIDEO,
                 method=HttpMethod.POST,
-                request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
-                response_model=PikaGenerateResponse,
+                request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
+                response_model=pika_defs.PikaGenerateResponse,
             ),
-            request=PikaBodyGenerate22T2vGenerate22T2vPost(
+            request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
                 promptText=prompt_text,
                 negativePrompt=negative_prompt,
                 seed=seed,
@@ -313,7 +211,7 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
         return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
 
-class PikaScenesV2_2(comfy_io.ComfyNode):
+class PikaScenes(comfy_io.ComfyNode):
     """PikaScenes v2.2 Node."""
 
     @classmethod
@@ -389,7 +287,6 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
         image_ingredient_4: Optional[torch.Tensor] = None,
         image_ingredient_5: Optional[torch.Tensor] = None,
     ) -> comfy_io.NodeOutput:
-        # Convert all passed images to BytesIO
         all_image_bytes_io = []
         for image in [
             image_ingredient_1,
@@ -399,16 +296,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
             image_ingredient_5,
         ]:
             if image is not None:
-                image_bytes_io = tensor_to_bytesio(image)
-                image_bytes_io.seek(0)
-                all_image_bytes_io.append(image_bytes_io)
+                all_image_bytes_io.append(tensor_to_bytesio(image))
 
         pika_files = [
             ("images", (f"image_{i}.png", image_bytes_io, "image/png"))
             for i, image_bytes_io in enumerate(all_image_bytes_io)
         ]
 
-        pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
+        pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
             ingredientsMode=ingredients_mode,
             promptText=prompt_text,
             negativePrompt=negative_prompt,
@@ -425,8 +320,8 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
             endpoint=ApiEndpoint(
                 path=PATH_PIKASCENES,
                 method=HttpMethod.POST,
-                request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
-                response_model=PikaGenerateResponse,
+                request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
+                response_model=pika_defs.PikaGenerateResponse,
             ),
             request=pika_request_data,
             files=pika_files,
@@ -477,22 +372,16 @@ class PikAdditionsNode(comfy_io.ComfyNode):
         negative_prompt: str,
         seed: int,
     ) -> comfy_io.NodeOutput:
-        # Convert video to BytesIO
         video_bytes_io = BytesIO()
         video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
         video_bytes_io.seek(0)
 
-        # Convert image to BytesIO
         image_bytes_io = tensor_to_bytesio(image)
-        image_bytes_io.seek(0)
 
         pika_files = {
             "video": ("video.mp4", video_bytes_io, "video/mp4"),
             "image": ("image.png", image_bytes_io, "image/png"),
         }
 
-        # Prepare non-file data
-        pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
+        pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
             promptText=prompt_text,
             negativePrompt=negative_prompt,
             seed=seed,
@@ -505,8 +394,8 @@ class PikAdditionsNode(comfy_io.ComfyNode):
             endpoint=ApiEndpoint(
                 path=PATH_PIKADDITIONS,
                 method=HttpMethod.POST,
-                request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
-                response_model=PikaGenerateResponse,
+                request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
+                response_model=pika_defs.PikaGenerateResponse,
             ),
             request=pika_request_data,
             files=pika_files,
@@ -529,11 +418,25 @@ class PikaSwapsNode(comfy_io.ComfyNode):
             category="api node/video/Pika",
             inputs=[
                 comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
-                comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."),
-                comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"),
-                comfy_io.String.Input("prompt_text", multiline=True),
-                comfy_io.String.Input("negative_prompt", multiline=True),
-                comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
+                comfy_io.Image.Input(
+                    "image",
+                    tooltip="The image used to replace the masked object in the video.",
+                    optional=True,
+                ),
+                comfy_io.Mask.Input(
+                    "mask",
+                    tooltip="Use the mask to define areas in the video to replace.",
+                    optional=True,
+                ),
+                comfy_io.String.Input("prompt_text", multiline=True, optional=True),
+                comfy_io.String.Input("negative_prompt", multiline=True, optional=True),
+                comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
+                comfy_io.String.Input(
+                    "region_to_modify",
+                    multiline=True,
+                    optional=True,
+                    tooltip="Plaintext description of the object / region to modify.",
+                ),
             ],
             outputs=[comfy_io.Video.Output()],
             hidden=[
@@ -548,41 +451,29 @@ class PikaSwapsNode(comfy_io.ComfyNode):
     async def execute(
         cls,
         video: VideoInput,
-        image: torch.Tensor,
-        mask: torch.Tensor,
-        prompt_text: str,
-        negative_prompt: str,
-        seed: int,
+        image: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        prompt_text: str = "",
+        negative_prompt: str = "",
+        seed: int = 0,
+        region_to_modify: str = "",
     ) -> comfy_io.NodeOutput:
-        # Convert video to BytesIO
         video_bytes_io = BytesIO()
         video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
         video_bytes_io.seek(0)
 
-        # Convert mask to binary mask with three channels
-        mask = torch.round(mask)
-        mask = mask.repeat(1, 3, 1, 1)
-
-        # Convert 3-channel binary mask to BytesIO
-        mask_bytes_io = BytesIO()
-        mask_bytes_io.write(mask.numpy().astype(np.uint8))
-        mask_bytes_io.seek(0)
-
-        # Convert image to BytesIO
-        image_bytes_io = tensor_to_bytesio(image)
-        image_bytes_io.seek(0)
-
         pika_files = {
             "video": ("video.mp4", video_bytes_io, "video/mp4"),
-            "image": ("image.png", image_bytes_io, "image/png"),
-            "modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
         }
+        if mask is not None:
+            pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
+        if image is not None:
+            pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
 
-        # Prepare non-file data
-        pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
+        pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
             promptText=prompt_text,
             negativePrompt=negative_prompt,
             seed=seed,
+            modifyRegionRoi=region_to_modify if region_to_modify else None,
         )
         auth = {
             "auth_token": cls.hidden.auth_token_comfy_org,
@@ -590,10 +481,10 @@ class PikaSwapsNode(comfy_io.ComfyNode):
         }
         initial_operation = SynchronousOperation(
             endpoint=ApiEndpoint(
-                path=PATH_PIKADDITIONS,
+                path=PATH_PIKASWAPS,
                 method=HttpMethod.POST,
-                request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
-                response_model=PikaGenerateResponse,
+                request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
+                response_model=pika_defs.PikaGenerateResponse,
             ),
             request=pika_request_data,
             files=pika_files,
@@ -616,7 +507,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
             inputs=[
                 comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
                 comfy_io.Combo.Input(
-                    "pikaffect", options=Pikaffect, default="Cake-ify"
+                    "pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
                 ),
                 comfy_io.String.Input("prompt_text", multiline=True),
                 comfy_io.String.Input("negative_prompt", multiline=True),
@@ -648,10 +539,10 @@ class PikaffectsNode(comfy_io.ComfyNode):
             endpoint=ApiEndpoint(
                 path=PATH_PIKAFFECTS,
                 method=HttpMethod.POST,
-                request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
-                response_model=PikaGenerateResponse,
+                request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
+                response_model=pika_defs.PikaGenerateResponse,
             ),
-            request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
+            request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
                 pikaffect=pikaffect,
                 promptText=prompt_text,
                 negativePrompt=negative_prompt,
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode): class PikaStartEndFrameNode(comfy_io.ComfyNode):
"""PikaFrames v2.2 Node.""" """PikaFrames v2.2 Node."""
@classmethod @classmethod
@@ -711,10 +602,10 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
             endpoint=ApiEndpoint(
                 path=PATH_PIKAFRAMES,
                 method=HttpMethod.POST,
-                request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
-                response_model=PikaGenerateResponse,
+                request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
+                response_model=pika_defs.PikaGenerateResponse,
             ),
-            request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
+            request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
                 promptText=prompt_text,
                 negativePrompt=negative_prompt,
                 seed=seed,
@@ -732,13 +623,13 @@ class PikaApiNodesExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
         return [
-            PikaImageToVideoV2_2,
-            PikaTextToVideoNodeV2_2,
-            PikaScenesV2_2,
+            PikaImageToVideo,
+            PikaTextToVideoNode,
+            PikaScenes,
             PikAdditionsNode,
             PikaSwapsNode,
             PikaffectsNode,
-            PikaStartEndFrameNode2_2,
+            PikaStartEndFrameNode,
         ]

View File

@@ -1,60 +1,80 @@
 import node_helpers
 import comfy.utils
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
 
-class CLIPTextEncodeFlux:
+
+class CLIPTextEncodeFlux(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "clip": ("CLIP", ),
-            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
-            }}
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "encode"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeFlux",
+            category="advanced/conditioning/flux",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
+                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
 
-    CATEGORY = "advanced/conditioning/flux"
-
-    def encode(self, clip, clip_l, t5xxl, guidance):
+    @classmethod
+    def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
         tokens = clip.tokenize(clip_l)
         tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
-        return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
+        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
 
+    encode = execute  # TODO: remove
 
-class FluxGuidance:
+
+class FluxGuidance(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "conditioning": ("CONDITIONING", ),
-            "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
-            }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="FluxGuidance",
+            category="advanced/conditioning/flux",
+            inputs=[
+                io.Conditioning.Input("conditioning"),
+                io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "append"
-    CATEGORY = "advanced/conditioning/flux"
-
-    def append(self, conditioning, guidance):
+    @classmethod
+    def execute(cls, conditioning, guidance) -> io.NodeOutput:
         c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
-        return (c, )
+        return io.NodeOutput(c)
 
+    append = execute  # TODO: remove
 
-class FluxDisableGuidance:
+
+class FluxDisableGuidance(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "conditioning": ("CONDITIONING", ),
-            }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="FluxDisableGuidance",
+            category="advanced/conditioning/flux",
+            description="This node completely disables the guidance embed on Flux and Flux like models",
+            inputs=[
+                io.Conditioning.Input("conditioning"),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
 
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "append"
-    CATEGORY = "advanced/conditioning/flux"
-    DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
-
-    def append(self, conditioning):
+    @classmethod
+    def execute(cls, conditioning) -> io.NodeOutput:
         c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
-        return (c, )
+        return io.NodeOutput(c)
 
+    append = execute  # TODO: remove
 
 PREFERED_KONTEXT_RESOLUTIONS = [
@@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
 ]
 
-class FluxKontextImageScale:
+class FluxKontextImageScale(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"image": ("IMAGE", ),
-                            },
-               }
+    def define_schema(cls):
+        return io.Schema(
+            node_id="FluxKontextImageScale",
+            category="advanced/conditioning/flux",
+            description="This node resizes the image to one that is more optimal for flux kontext.",
+            inputs=[
+                io.Image.Input("image"),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
 
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "scale"
-    CATEGORY = "advanced/conditioning/flux"
-    DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
-
-    def scale(self, image):
+    @classmethod
+    def execute(cls, image) -> io.NodeOutput:
         width = image.shape[2]
         height = image.shape[1]
         aspect_ratio = width / height
         _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
         image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
-        return (image, )
+        return io.NodeOutput(image)
 
+    scale = execute  # TODO: remove
 
-class FluxKontextMultiReferenceLatentMethod:
+
+class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "conditioning": ("CONDITIONING", ),
-            "reference_latents_method": (("offset", "index", "uxo/uno"), ),
-            }}
+    def define_schema(cls):
+        return io.Schema(
+            node_id="FluxKontextMultiReferenceLatentMethod",
+            category="advanced/conditioning/flux",
+            inputs=[
+                io.Conditioning.Input("conditioning"),
+                io.Combo.Input(
+                    "reference_latents_method",
+                    options=["offset", "index", "uxo/uno"],
+                ),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+            is_experimental=True,
+        )
 
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "append"
-    EXPERIMENTAL = True
-    CATEGORY = "advanced/conditioning/flux"
-
-    def append(self, conditioning, reference_latents_method):
+    @classmethod
+    def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
         if "uxo" in reference_latents_method or "uso" in reference_latents_method:
             reference_latents_method = "uxo"
         c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
-        return (c, )
+        return io.NodeOutput(c)
 
-NODE_CLASS_MAPPINGS = {
-    "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
-    "FluxGuidance": FluxGuidance,
-    "FluxDisableGuidance": FluxDisableGuidance,
-    "FluxKontextImageScale": FluxKontextImageScale,
-    "FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
-}
+    append = execute  # TODO: remove
+
+
+class FluxExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            CLIPTextEncodeFlux,
+            FluxGuidance,
+            FluxDisableGuidance,
+            FluxKontextImageScale,
+            FluxKontextMultiReferenceLatentMethod,
+        ]
+
+
+async def comfy_entrypoint() -> FluxExtension:
+    return FluxExtension()
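
The NODE_CLASS_MAPPINGS dict gives way to a ComfyExtension subclass plus an async comfy_entrypoint() hook. Roughly how a consumer would drive it, as a sketch (the exact discovery mechanics live in ComfyUI's extension loader, not in this file):

import asyncio

async def main():
    ext = await comfy_entrypoint()        # the hook ComfyUI looks for in the module
    nodes = await ext.get_node_list()     # [CLIPTextEncodeFlux, FluxGuidance, ...]
    print([node.define_schema().node_id for node in nodes])

asyncio.run(main())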

View File

@@ -3,64 +3,83 @@ import comfy.sd
 import comfy.model_management
 import nodes
 import torch
-import comfy_extras.nodes_slg
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
 
-class TripleCLIPLoader:
+
+class TripleCLIPLoader(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
-                             }}
-    RETURN_TYPES = ("CLIP",)
-    FUNCTION = "load_clip"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TripleCLIPLoader",
+            category="advanced/loaders",
+            description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
+            inputs=[
+                io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
+                io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
+            ],
+            outputs=[
+                io.Clip.Output(),
+            ],
+        )
 
-    CATEGORY = "advanced/loaders"
-
-    DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
-
-    def load_clip(self, clip_name1, clip_name2, clip_name3):
+    @classmethod
+    def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
         clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
         clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
         clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
         clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
-        return (clip,)
+        return io.NodeOutput(clip)
 
+    load_clip = execute  # TODO: remove
 
-class EmptySD3LatentImage:
-    def __init__(self):
-        self.device = comfy.model_management.intermediate_device()
+
+class EmptySD3LatentImage(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptySD3LatentImage",
+            category="latent/sd3",
+            inputs=[
+                io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
 
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
-    RETURN_TYPES = ("LATENT",)
-    FUNCTION = "generate"
+    def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
+        latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        return io.NodeOutput({"samples":latent})
 
-    CATEGORY = "latent/sd3"
+    generate = execute  # TODO: remove
 
-    def generate(self, width, height, batch_size=1):
-        latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
-        return ({"samples":latent}, )
 
-class CLIPTextEncodeSD3:
+class CLIPTextEncodeSD3(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {
-            "clip": ("CLIP", ),
-            "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
-            "empty_padding": (["none", "empty_prompt"], )
-            }}
-    RETURN_TYPES = ("CONDITIONING",)
-    FUNCTION = "encode"
+    def define_schema(cls):
+        return io.Schema(
+            node_id="CLIPTextEncodeSD3",
+            category="advanced/conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
+                io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
+                io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
+                io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
+            ],
+            outputs=[
+                io.Conditioning.Output(),
+            ],
+        )
 
-    CATEGORY = "advanced/conditioning"
-
-    def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
+    @classmethod
+    def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
         no_padding = empty_padding == "none"
 
         tokens = clip.tokenize(clip_g)
@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
tokens["l"] += empty["l"] tokens["l"] += empty["l"]
while len(tokens["l"]) > len(tokens["g"]): while len(tokens["l"]) > len(tokens["g"]):
tokens["g"] += empty["g"] tokens["g"] += empty["g"]
return (clip.encode_from_tokens_scheduled(tokens), ) return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
encode = execute # TODO: remove
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): class ControlNetApplySD3(io.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls) -> io.Schema:
return {"required": {"positive": ("CONDITIONING", ), return io.Schema(
"negative": ("CONDITIONING", ), node_id="ControlNetApplySD3",
"control_net": ("CONTROL_NET", ), display_name="Apply Controlnet with VAE",
"vae": ("VAE", ), category="conditioning/controlnet",
"image": ("IMAGE", ), inputs=[
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), io.Conditioning.Input("positive"),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), io.Conditioning.Input("negative"),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) io.ControlNet.Input("control_net"),
}} io.Vae.Input("vae"),
CATEGORY = "conditioning/controlnet" io.Image.Input("image"),
DEPRECATED = True io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
is_deprecated=True,
)
@classmethod
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
if strength == 0:
return io.NodeOutput(positive, negative)
control_hint = image.movedim(-1, 1)
cnets = {}
out = []
for conditioning in [positive, negative]:
c = []
for t in conditioning:
d = t[1].copy()
prev_cnet = d.get('control', None)
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
vae=vae, extra_concat=[])
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
d['control'] = c_net
d['control_apply_to_uncond'] = False
n = [t[0], d]
c.append(n)
out.append(c)
return io.NodeOutput(out[0], out[1])
apply_controlnet = execute # TODO: remove
-class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
+class SkipLayerGuidanceSD3(io.ComfyNode):
     '''
     Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
     Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
     Experimental implementation by Dango233@StabilityAI.
     '''
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"model": ("MODEL", ),
-                             "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
-                             "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
-                             "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
-                             }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "skip_guidance_sd3"
-
-    CATEGORY = "advanced/guidance"
-
-    def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
-        return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SkipLayerGuidanceSD3",
+            category="advanced/guidance",
+            description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("layers", default="7, 8, 9", multiline=False),
+                io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
+                io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+            is_experimental=True,
+        )
+
+    @classmethod
+    def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
+        return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
+
+    skip_guidance_sd3 = execute # TODO: remove
-NODE_CLASS_MAPPINGS = {
-    "TripleCLIPLoader": TripleCLIPLoader,
-    "EmptySD3LatentImage": EmptySD3LatentImage,
-    "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
-    "ControlNetApplySD3": ControlNetApplySD3,
-    "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    # Sampling
-    "ControlNetApplySD3": "Apply Controlnet with VAE",
-}
+class SD3Extension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            TripleCLIPLoader,
+            EmptySD3LatentImage,
+            CLIPTextEncodeSD3,
+            ControlNetApplySD3,
+            SkipLayerGuidanceSD3,
+        ]
+
+async def comfy_entrypoint() -> SD3Extension:
+    return SD3Extension()
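The module-level NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dicts are replaced by an extension object the host discovers through comfy_entrypoint(). The loader side is not part of this diff; a hedged sketch of the contract the async signatures imply (load_extension and the returned dict are hypothetical, shown only to make the handshake concrete):

    import inspect

    async def load_extension(module) -> dict:
        # hypothetical host-side logic, inferred from the signatures above
        entry = module.comfy_entrypoint()
        extension = await entry if inspect.isawaitable(entry) else entry
        nodes = await extension.get_node_list()
        # rebuild an old-style mapping keyed by node_id for legacy callers
        return {cls.define_schema().node_id: cls for cls in nodes}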

View File

@@ -1,33 +1,40 @@
 import comfy.model_patcher
 import comfy.samplers
 import re
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io

-class SkipLayerGuidanceDiT:
+class SkipLayerGuidanceDiT(io.ComfyNode):
     '''
     Enhance guidance towards detailed structure by having another set of CFG negative with skipped layers.
     Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
     Original experimental implementation for SD3 by Dango233@StabilityAI.
     '''
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"model": ("MODEL", ),
-                             "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
-                             "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
-                             "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
-                             "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             "rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
-                             }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "skip_guidance"
-    EXPERIMENTAL = True
-
-    DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
-
-    CATEGORY = "advanced/guidance"
-
-    def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SkipLayerGuidanceDiT",
+            category="advanced/guidance",
+            description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("double_layers", default="7, 8, 9"),
+                io.String.Input("single_layers", default="7, 8, 9"),
+                io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
+                io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
         # check if layer is comma separated integers
         def skip(args, extra_args):
             return args

@@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
         single_layers = [int(i) for i in single_layers]

         if len(double_layers) == 0 and len(single_layers) == 0:
-            return (model, )
+            return io.NodeOutput(model)

         def post_cfg_function(args):
             model = args["model"]

@@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
         m = model.clone()
         m.set_model_sampler_post_cfg_function(post_cfg_function)
-        return (m, )
+        return io.NodeOutput(m)
+
+    skip_guidance = execute # TODO: remove
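The hunks above elide the body behind "# check if layer is comma separated integers". Given the re import at the top of the file and the visible single_layers = [int(i) for i in single_layers], the normalization is presumably along these lines (parse_layers is an illustrative name, not the identifier used in this file):

    import re

    def parse_layers(spec: str) -> list[int]:
        # pull every integer out of a string like "7, 8, 9"
        return [int(i) for i in re.findall(r"\d+", spec)]

    assert parse_layers("7, 8, 9") == [7, 8, 9]
    assert parse_layers("") == []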
-class SkipLayerGuidanceDiTSimple:
+class SkipLayerGuidanceDiTSimple(io.ComfyNode):
     '''
     Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
     '''
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"model": ("MODEL", ),
-                             "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
-                             "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
-                             "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
-                             }}
-    RETURN_TYPES = ("MODEL",)
-    FUNCTION = "skip_guidance"
-    EXPERIMENTAL = True
-
-    DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
-
-    CATEGORY = "advanced/guidance"
-
-    def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="SkipLayerGuidanceDiTSimple",
+            category="advanced/guidance",
+            description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.String.Input("double_layers", default="7, 8, 9"),
+                io.String.Input("single_layers", default="7, 8, 9"),
+                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
+                io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
         def skip(args, extra_args):
             return args

@@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
         single_layers = [int(i) for i in single_layers]

         if len(double_layers) == 0 and len(single_layers) == 0:
-            return (model, )
+            return io.NodeOutput(model)

         def calc_cond_batch_function(args):
             x = args["input"]

@@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
         m = model.clone()
         m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
-        return (m, )
+        return io.NodeOutput(m)
+
+    skip_guidance = execute # TODO: remove

-NODE_CLASS_MAPPINGS = {
-    "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
-    "SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
-}
+
+class SkipLayerGuidanceExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            SkipLayerGuidanceDiT,
+            SkipLayerGuidanceDiTSimple,
+        ]
+
+async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
+    return SkipLayerGuidanceExtension()
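Both classes hand the sampler a callback rather than patching model weights: SkipLayerGuidanceDiT installs a post-CFG hook on a clone, SkipLayerGuidanceDiTSimple a calc_cond_batch hook. The hook bodies are elided in this diff; a sketch of the shape they take (key names other than "model" and "input" are assumptions based on the visible context):

    def make_post_cfg_hook(scale: float):
        def post_cfg_function(args: dict):
            cfg_result = args["denoised"]  # assumed key: the combined CFG prediction
            # the real body re-runs the model with skip() patched over the chosen
            # layers, then pushes cfg_result away from that skipped-layer prediction
            return cfg_result
        return post_cfg_function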

View File

@@ -4,6 +4,8 @@ from comfy import model_management
 import torch
 import comfy.utils
 import folder_paths
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io

 try:
     from spandrel_extra_arches import EXTRA_REGISTRY

@@ -13,17 +15,23 @@ try:
 except:
     pass

-class UpscaleModelLoader:
+class UpscaleModelLoader(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
-                             }}
-    RETURN_TYPES = ("UPSCALE_MODEL",)
-    FUNCTION = "load_model"
-
-    CATEGORY = "loaders"
-
-    def load_model(self, model_name):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="UpscaleModelLoader",
+            display_name="Load Upscale Model",
+            category="loaders",
+            inputs=[
+                io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
+            ],
+            outputs=[
+                io.UpscaleModel.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model_name) -> io.NodeOutput:
         model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
         sd = comfy.utils.load_torch_file(model_path, safe_load=True)
         if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:

@@ -33,21 +41,29 @@ class UpscaleModelLoader:
         if not isinstance(out, ImageModelDescriptor):
             raise Exception("Upscale model must be a single-image model.")

-        return (out, )
+        return io.NodeOutput(out)
+
+    load_model = execute # TODO: remove
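The elided middle of this hunk still performs the actual load through spandrel; reconstructed from the visible context (the module. prefix check and the ImageModelDescriptor guard), it is approximately the following, a sketch rather than the literal file contents:

    from spandrel import ImageModelDescriptor, ModelLoader
    import comfy.utils
    import folder_paths

    model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
    sd = comfy.utils.load_torch_file(model_path, safe_load=True)
    if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
        # strip a DataParallel-style "module." prefix before handing off to spandrel
        sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": ""})
    out = ModelLoader().load_from_state_dict(sd).eval()
    if not isinstance(out, ImageModelDescriptor):
        raise Exception("Upscale model must be a single-image model.")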
-class ImageUpscaleWithModel:
+class ImageUpscaleWithModel(io.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "upscale_model": ("UPSCALE_MODEL",),
-                              "image": ("IMAGE",),
-                              }}
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "upscale"
-
-    CATEGORY = "image/upscaling"
-
-    def upscale(self, upscale_model, image):
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ImageUpscaleWithModel",
+            display_name="Upscale Image (using Model)",
+            category="image/upscaling",
+            inputs=[
+                io.UpscaleModel.Input("upscale_model"),
+                io.Image.Input("image"),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, upscale_model, image) -> io.NodeOutput:
         device = model_management.get_torch_device()

         memory_required = model_management.module_size(upscale_model.model)

@@ -75,9 +91,19 @@ class ImageUpscaleWithModel:
         upscale_model.to("cpu")
         s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
-        return (s,)
+        return io.NodeOutput(s)
+
+    upscale = execute # TODO: remove
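The visible tail of execute() converts back with s.movedim(-3, -1) and clamps. ComfyUI IMAGE tensors are [batch, height, width, channel] floats in 0..1, while the spandrel model runs on [batch, channel, height, width], so the body is a movedim round trip. As a sketch (the call shape of upscale_model is assumed; the elided hunk presumably tiles the input given the memory_required estimate above):

    import torch

    def upscale_roundtrip(upscale_model, image: torch.Tensor, device) -> torch.Tensor:
        in_img = image.movedim(-1, -3).to(device)  # [B,H,W,C] -> [B,C,H,W]
        s = upscale_model(in_img)                  # assumed call shape
        return torch.clamp(s.movedim(-3, -1), min=0, max=1.0)  # back to [B,H,W,C]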
-NODE_CLASS_MAPPINGS = {
-    "UpscaleModelLoader": UpscaleModelLoader,
-    "ImageUpscaleWithModel": ImageUpscaleWithModel
-}
+
+class UpscaleModelExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            UpscaleModelLoader,
+            ImageUpscaleWithModel,
+        ]
+
+async def comfy_entrypoint() -> UpscaleModelExtension:
+    return UpscaleModelExtension()

View File

@@ -2030,7 +2030,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"DiffControlNetLoader": "Load ControlNet Model (diff)", "DiffControlNetLoader": "Load ControlNet Model (diff)",
"StyleModelLoader": "Load Style Model", "StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision", "CLIPVisionLoader": "Load CLIP Vision",
"UpscaleModelLoader": "Load Upscale Model",
"UNETLoader": "Load Diffusion Model", "UNETLoader": "Load Diffusion Model",
# Conditioning # Conditioning
"CLIPVisionEncode": "CLIP Vision Encode", "CLIPVisionEncode": "CLIP Vision Encode",
@ -2068,7 +2067,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadImageOutput": "Load Image (from Outputs)", "LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image", "ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By", "ImageScaleBy": "Upscale Image By",
"ImageUpscaleWithModel": "Upscale Image (using Model)",
"ImageInvert": "Invert Image", "ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting", "ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images", "ImageBatch": "Batch Images",