mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
4e7f2eeae2
@ -123,16 +123,30 @@ def move_weight_functions(m, device):
|
|||||||
return memory
|
return memory
|
||||||
|
|
||||||
class LowVramPatch:
|
class LowVramPatch:
|
||||||
def __init__(self, key, patches):
|
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.patches = patches
|
self.patches = patches
|
||||||
|
self.convert_func = convert_func
|
||||||
|
self.set_func = set_func
|
||||||
|
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
intermediate_dtype = weight.dtype
|
intermediate_dtype = weight.dtype
|
||||||
|
if self.convert_func is not None:
|
||||||
|
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
|
||||||
|
|
||||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||||
intermediate_dtype = torch.float32
|
intermediate_dtype = torch.float32
|
||||||
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
if self.set_func is None:
|
||||||
|
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
|
||||||
|
else:
|
||||||
|
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
|
||||||
|
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
if self.set_func is not None:
|
||||||
|
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
|
||||||
|
else:
|
||||||
|
return out
|
||||||
|
|
||||||
def get_key_weight(model, key):
|
def get_key_weight(model, key):
|
||||||
set_func = None
|
set_func = None
|
||||||
@ -657,13 +671,15 @@ class ModelPatcher:
|
|||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(weight_key)
|
self.patch_weight_to_device(weight_key)
|
||||||
else:
|
else:
|
||||||
m.weight_function = [LowVramPatch(weight_key, self.patches)]
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||||
|
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(bias_key)
|
self.patch_weight_to_device(bias_key)
|
||||||
else:
|
else:
|
||||||
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||||
|
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
@ -825,10 +841,12 @@ class ModelPatcher:
|
|||||||
module_mem += move_weight_functions(m, device_to)
|
module_mem += move_weight_functions(m, device_to)
|
||||||
if lowvram_possible:
|
if lowvram_possible:
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function.append(LowVramPatch(weight_key, self.patches))
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||||
|
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||||
|
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
|
|
||||||
|
|||||||
@ -416,8 +416,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
else:
|
else:
|
||||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||||
|
|
||||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||||
|
if return_weight:
|
||||||
|
return weight
|
||||||
if inplace_update:
|
if inplace_update:
|
||||||
self.weight.data.copy_(weight)
|
self.weight.data.copy_(weight)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
|
|||||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||||
from comfy_api.latest._io import _IO as io #noqa: F401
|
from . import _io as io
|
||||||
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
from . import _ui as ui
|
||||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
@ -114,6 +114,8 @@ if TYPE_CHECKING:
|
|||||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||||
|
|
||||||
|
comfy_io = io # create the new alias for io
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ComfyAPI",
|
"ComfyAPI",
|
||||||
"ComfyAPISync",
|
"ComfyAPISync",
|
||||||
@ -121,4 +123,7 @@ __all__ = [
|
|||||||
"InputImpl",
|
"InputImpl",
|
||||||
"Types",
|
"Types",
|
||||||
"ComfyExtension",
|
"ComfyExtension",
|
||||||
|
"io",
|
||||||
|
"comfy_io",
|
||||||
|
"ui",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, IO
|
||||||
import io
|
import io
|
||||||
import av
|
import av
|
||||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||||
@ -23,7 +23,7 @@ class VideoInput(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_to(
|
def save_to(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: Union[str, IO[bytes]],
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None
|
||||||
|
|||||||
@ -1582,78 +1582,78 @@ class _UIOutput(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class _IO:
|
__all__ = [
|
||||||
FolderType = FolderType
|
"FolderType",
|
||||||
UploadType = UploadType
|
"UploadType",
|
||||||
RemoteOptions = RemoteOptions
|
"RemoteOptions",
|
||||||
NumberDisplay = NumberDisplay
|
"NumberDisplay",
|
||||||
|
|
||||||
comfytype = staticmethod(comfytype)
|
"comfytype",
|
||||||
Custom = staticmethod(Custom)
|
"Custom",
|
||||||
Input = Input
|
"Input",
|
||||||
WidgetInput = WidgetInput
|
"WidgetInput",
|
||||||
Output = Output
|
"Output",
|
||||||
ComfyTypeI = ComfyTypeI
|
"ComfyTypeI",
|
||||||
ComfyTypeIO = ComfyTypeIO
|
"ComfyTypeIO",
|
||||||
#---------------------------------
|
|
||||||
# Supported Types
|
# Supported Types
|
||||||
Boolean = Boolean
|
"Boolean",
|
||||||
Int = Int
|
"Int",
|
||||||
Float = Float
|
"Float",
|
||||||
String = String
|
"String",
|
||||||
Combo = Combo
|
"Combo",
|
||||||
MultiCombo = MultiCombo
|
"MultiCombo",
|
||||||
Image = Image
|
"Image",
|
||||||
WanCameraEmbedding = WanCameraEmbedding
|
"WanCameraEmbedding",
|
||||||
Webcam = Webcam
|
"Webcam",
|
||||||
Mask = Mask
|
"Mask",
|
||||||
Latent = Latent
|
"Latent",
|
||||||
Conditioning = Conditioning
|
"Conditioning",
|
||||||
Sampler = Sampler
|
"Sampler",
|
||||||
Sigmas = Sigmas
|
"Sigmas",
|
||||||
Noise = Noise
|
"Noise",
|
||||||
Guider = Guider
|
"Guider",
|
||||||
Clip = Clip
|
"Clip",
|
||||||
ControlNet = ControlNet
|
"ControlNet",
|
||||||
Vae = Vae
|
"Vae",
|
||||||
Model = Model
|
"Model",
|
||||||
ClipVision = ClipVision
|
"ClipVision",
|
||||||
ClipVisionOutput = ClipVisionOutput
|
"ClipVisionOutput",
|
||||||
AudioEncoder = AudioEncoder
|
"AudioEncoder",
|
||||||
AudioEncoderOutput = AudioEncoderOutput
|
"AudioEncoderOutput",
|
||||||
StyleModel = StyleModel
|
"StyleModel",
|
||||||
Gligen = Gligen
|
"Gligen",
|
||||||
UpscaleModel = UpscaleModel
|
"UpscaleModel",
|
||||||
Audio = Audio
|
"Audio",
|
||||||
Video = Video
|
"Video",
|
||||||
SVG = SVG
|
"SVG",
|
||||||
LoraModel = LoraModel
|
"LoraModel",
|
||||||
LossMap = LossMap
|
"LossMap",
|
||||||
Voxel = Voxel
|
"Voxel",
|
||||||
Mesh = Mesh
|
"Mesh",
|
||||||
Hooks = Hooks
|
"Hooks",
|
||||||
HookKeyframes = HookKeyframes
|
"HookKeyframes",
|
||||||
TimestepsRange = TimestepsRange
|
"TimestepsRange",
|
||||||
LatentOperation = LatentOperation
|
"LatentOperation",
|
||||||
FlowControl = FlowControl
|
"FlowControl",
|
||||||
Accumulation = Accumulation
|
"Accumulation",
|
||||||
Load3DCamera = Load3DCamera
|
"Load3DCamera",
|
||||||
Load3D = Load3D
|
"Load3D",
|
||||||
Load3DAnimation = Load3DAnimation
|
"Load3DAnimation",
|
||||||
Photomaker = Photomaker
|
"Photomaker",
|
||||||
Point = Point
|
"Point",
|
||||||
FaceAnalysis = FaceAnalysis
|
"FaceAnalysis",
|
||||||
BBOX = BBOX
|
"BBOX",
|
||||||
SEGS = SEGS
|
"SEGS",
|
||||||
AnyType = AnyType
|
"AnyType",
|
||||||
MultiType = MultiType
|
"MultiType",
|
||||||
#---------------------------------
|
# Other classes
|
||||||
HiddenHolder = HiddenHolder
|
"HiddenHolder",
|
||||||
Hidden = Hidden
|
"Hidden",
|
||||||
NodeInfoV1 = NodeInfoV1
|
"NodeInfoV1",
|
||||||
NodeInfoV3 = NodeInfoV3
|
"NodeInfoV3",
|
||||||
Schema = Schema
|
"Schema",
|
||||||
ComfyNode = ComfyNode
|
"ComfyNode",
|
||||||
NodeOutput = NodeOutput
|
"NodeOutput",
|
||||||
add_to_dict_v1 = staticmethod(add_to_dict_v1)
|
"add_to_dict_v1",
|
||||||
add_to_dict_v3 = staticmethod(add_to_dict_v3)
|
"add_to_dict_v3",
|
||||||
|
]
|
||||||
|
|||||||
@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
|
|||||||
return {"text": (self.value,)}
|
return {"text": (self.value,)}
|
||||||
|
|
||||||
|
|
||||||
class _UI:
|
__all__ = [
|
||||||
SavedResult = SavedResult
|
"SavedResult",
|
||||||
SavedImages = SavedImages
|
"SavedImages",
|
||||||
SavedAudios = SavedAudios
|
"SavedAudios",
|
||||||
ImageSaveHelper = ImageSaveHelper
|
"ImageSaveHelper",
|
||||||
AudioSaveHelper = AudioSaveHelper
|
"AudioSaveHelper",
|
||||||
PreviewImage = PreviewImage
|
"PreviewImage",
|
||||||
PreviewMask = PreviewMask
|
"PreviewMask",
|
||||||
PreviewAudio = PreviewAudio
|
"PreviewAudio",
|
||||||
PreviewVideo = PreviewVideo
|
"PreviewVideo",
|
||||||
PreviewUI3D = PreviewUI3D
|
"PreviewUI3D",
|
||||||
PreviewText = PreviewText
|
"PreviewText",
|
||||||
|
]
|
||||||
|
|||||||
@ -269,7 +269,7 @@ def tensor_to_bytesio(
|
|||||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Named BytesIO object containing the image data.
|
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||||
"""
|
"""
|
||||||
if not mime_type:
|
if not mime_type:
|
||||||
mime_type = "image/png"
|
mime_type = "image/png"
|
||||||
|
|||||||
@ -98,7 +98,7 @@ import io
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
||||||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
|
from typing import Type, Optional, Any, TypeVar, Generic, Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import json
|
import json
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
@ -175,7 +175,7 @@ class ApiClient:
|
|||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
retry_backoff_factor: float = 2.0,
|
retry_backoff_factor: float = 2.0,
|
||||||
retry_status_codes: Optional[Tuple[int, ...]] = None,
|
retry_status_codes: Optional[tuple[int, ...]] = None,
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
):
|
):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@ -199,9 +199,9 @@ class ApiClient:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_json_payload_args(
|
def _create_json_payload_args(
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"json": data,
|
"json": data,
|
||||||
"headers": headers,
|
"headers": headers,
|
||||||
@ -209,11 +209,11 @@ class ApiClient:
|
|||||||
|
|
||||||
def _create_form_data_args(
|
def _create_form_data_args(
|
||||||
self,
|
self,
|
||||||
data: Dict[str, Any] | None,
|
data: dict[str, Any] | None,
|
||||||
files: Dict[str, Any] | None,
|
files: dict[str, Any] | None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
multipart_parser: Callable | None = None,
|
multipart_parser: Callable | None = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if headers and "Content-Type" in headers:
|
if headers and "Content-Type" in headers:
|
||||||
del headers["Content-Type"]
|
del headers["Content-Type"]
|
||||||
|
|
||||||
@ -254,9 +254,9 @@ class ApiClient:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_urlencoded_form_data_args(
|
def _create_urlencoded_form_data_args(
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
headers = headers or {}
|
headers = headers or {}
|
||||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||||
return {
|
return {
|
||||||
@ -264,7 +264,7 @@ class ApiClient:
|
|||||||
"headers": headers,
|
"headers": headers,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_headers(self) -> Dict[str, str]:
|
def get_headers(self) -> dict[str, str]:
|
||||||
"""Get headers for API requests, including authentication if available"""
|
"""Get headers for API requests, including authentication if available"""
|
||||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||||
|
|
||||||
@ -275,7 +275,7 @@ class ApiClient:
|
|||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
|
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
|
||||||
"""
|
"""
|
||||||
Check connectivity to determine if network issues are local or server-related.
|
Check connectivity to determine if network issues are local or server-related.
|
||||||
|
|
||||||
@ -316,14 +316,14 @@ class ApiClient:
|
|||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
path: str,
|
path: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[dict[str, Any]] = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
content_type: str = "application/json",
|
content_type: str = "application/json",
|
||||||
multipart_parser: Callable | None = None,
|
multipart_parser: Callable | None = None,
|
||||||
retry_count: int = 0, # Used internally for tracking retries
|
retry_count: int = 0, # Used internally for tracking retries
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Make an HTTP request to the API with automatic retries for transient errors.
|
Make an HTTP request to the API with automatic retries for transient errors.
|
||||||
|
|
||||||
@ -485,7 +485,7 @@ class ApiClient:
|
|||||||
retry_delay: Initial delay between retries in seconds
|
retry_delay: Initial delay between retries in seconds
|
||||||
retry_backoff_factor: Multiplier for the delay after each retry
|
retry_backoff_factor: Multiplier for the delay after each retry
|
||||||
"""
|
"""
|
||||||
headers: Dict[str, str] = {}
|
headers: dict[str, str] = {}
|
||||||
skip_auto_headers: set[str] = set()
|
skip_auto_headers: set[str] = set()
|
||||||
if content_type:
|
if content_type:
|
||||||
headers["Content-Type"] = content_type
|
headers["Content-Type"] = content_type
|
||||||
@ -558,7 +558,7 @@ class ApiClient:
|
|||||||
*req_meta,
|
*req_meta,
|
||||||
retry_count: int,
|
retry_count: int,
|
||||||
response_content: dict | str = "",
|
response_content: dict | str = "",
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
status_code = exc.status
|
status_code = exc.status
|
||||||
if status_code == 401:
|
if status_code == 401:
|
||||||
user_friendly = "Unauthorized: Please login first to use this node."
|
user_friendly = "Unauthorized: Please login first to use this node."
|
||||||
@ -659,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
|
|||||||
method: HttpMethod,
|
method: HttpMethod,
|
||||||
request_model: Type[T],
|
request_model: Type[T],
|
||||||
response_model: Type[R],
|
response_model: Type[R],
|
||||||
query_params: Optional[Dict[str, Any]] = None,
|
query_params: Optional[dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""Initialize an API endpoint definition.
|
"""Initialize an API endpoint definition.
|
||||||
|
|
||||||
@ -684,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
|
|||||||
self,
|
self,
|
||||||
endpoint: ApiEndpoint[T, R],
|
endpoint: ApiEndpoint[T, R],
|
||||||
request: T,
|
request: T,
|
||||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
auth_token: Optional[str] = None,
|
auth_token: Optional[str] = None,
|
||||||
comfy_api_key: Optional[str] = None,
|
comfy_api_key: Optional[str] = None,
|
||||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
timeout: float = 7200.0,
|
timeout: float = 7200.0,
|
||||||
verify_ssl: bool = True,
|
verify_ssl: bool = True,
|
||||||
content_type: str = "application/json",
|
content_type: str = "application/json",
|
||||||
@ -729,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request_dict: Optional[Dict[str, Any]]
|
request_dict: Optional[dict[str, Any]]
|
||||||
if isinstance(self.request, EmptyRequest):
|
if isinstance(self.request, EmptyRequest):
|
||||||
request_dict = None
|
request_dict = None
|
||||||
else:
|
else:
|
||||||
@ -782,14 +782,14 @@ class PollingOperation(Generic[T, R]):
|
|||||||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
||||||
completed_statuses: list[str],
|
completed_statuses: list[str],
|
||||||
failed_statuses: list[str],
|
failed_statuses: list[str],
|
||||||
status_extractor: Callable[[R], str],
|
status_extractor: Callable[[R], Optional[str]],
|
||||||
progress_extractor: Callable[[R], float] | None = None,
|
progress_extractor: Callable[[R], Optional[float]] | None = None,
|
||||||
result_url_extractor: Callable[[R], str] | None = None,
|
result_url_extractor: Callable[[R], Optional[str]] | None = None,
|
||||||
request: Optional[T] = None,
|
request: Optional[T] = None,
|
||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
auth_token: Optional[str] = None,
|
auth_token: Optional[str] = None,
|
||||||
comfy_api_key: Optional[str] = None,
|
comfy_api_key: Optional[str] = None,
|
||||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
poll_interval: float = 5.0,
|
poll_interval: float = 5.0,
|
||||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
||||||
max_retries: int = 3, # Max retries per individual API call
|
max_retries: int = 3, # Max retries per individual API call
|
||||||
|
|||||||
100
comfy_api_nodes/apis/pika_defs.py
Normal file
100
comfy_api_nodes/apis/pika_defs.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Pikaffect(str, Enum):
|
||||||
|
Cake_ify = "Cake-ify"
|
||||||
|
Crumble = "Crumble"
|
||||||
|
Crush = "Crush"
|
||||||
|
Decapitate = "Decapitate"
|
||||||
|
Deflate = "Deflate"
|
||||||
|
Dissolve = "Dissolve"
|
||||||
|
Explode = "Explode"
|
||||||
|
Eye_pop = "Eye-pop"
|
||||||
|
Inflate = "Inflate"
|
||||||
|
Levitate = "Levitate"
|
||||||
|
Melt = "Melt"
|
||||||
|
Peel = "Peel"
|
||||||
|
Poke = "Poke"
|
||||||
|
Squish = "Squish"
|
||||||
|
Ta_da = "Ta-da"
|
||||||
|
Tear = "Tear"
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
|
||||||
|
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
|
||||||
|
duration: Optional[int] = Field(5)
|
||||||
|
ingredientsMode: str = Field(...)
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
resolution: Optional[str] = Field('1080p')
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaGenerateResponse(BaseModel):
|
||||||
|
video_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
|
||||||
|
duration: Optional[int] = 5
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
|
||||||
|
duration: Optional[int] = Field(None, ge=5, le=10)
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: str = Field(...)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
|
||||||
|
aspectRatio: Optional[float] = Field(
|
||||||
|
1.7777777777777777,
|
||||||
|
description='Aspect ratio (width / height)',
|
||||||
|
ge=0.4,
|
||||||
|
le=2.5,
|
||||||
|
)
|
||||||
|
duration: Optional[int] = 5
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: str = Field(...)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
pikaffect: Optional[str] = None
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
modifyRegionRoi: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaStatusEnum(str, Enum):
|
||||||
|
queued = "queued"
|
||||||
|
started = "started"
|
||||||
|
finished = "finished"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class PikaVideoResponse(BaseModel):
|
||||||
|
id: str = Field(...)
|
||||||
|
progress: Optional[int] = Field(None)
|
||||||
|
status: PikaStatusEnum
|
||||||
|
url: Optional[str] = Field(None)
|
||||||
@ -8,30 +8,17 @@ from __future__ import annotations
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, TypeVar
|
from typing import Optional, TypeVar
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import ComfyExtension, comfy_io
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||||
from comfy_api_nodes.apinode_utils import (
|
from comfy_api_nodes.apinode_utils import (
|
||||||
download_url_to_video_output,
|
download_url_to_video_output,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import pika_defs
|
||||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
|
||||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
|
||||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
|
||||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
|
||||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
|
||||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
|
||||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
|
||||||
PikaGenerateResponse,
|
|
||||||
PikaVideoResponse,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
EmptyRequest,
|
EmptyRequest,
|
||||||
@ -55,116 +42,36 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
|||||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||||
|
|
||||||
|
|
||||||
class PikaDurationEnum(int, Enum):
|
async def execute_task(
|
||||||
integer_5 = 5
|
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
|
||||||
integer_10 = 10
|
|
||||||
|
|
||||||
|
|
||||||
class PikaResolutionEnum(str, Enum):
|
|
||||||
field_1080p = "1080p"
|
|
||||||
field_720p = "720p"
|
|
||||||
|
|
||||||
|
|
||||||
class Pikaffect(str, Enum):
|
|
||||||
Cake_ify = "Cake-ify"
|
|
||||||
Crumble = "Crumble"
|
|
||||||
Crush = "Crush"
|
|
||||||
Decapitate = "Decapitate"
|
|
||||||
Deflate = "Deflate"
|
|
||||||
Dissolve = "Dissolve"
|
|
||||||
Explode = "Explode"
|
|
||||||
Eye_pop = "Eye-pop"
|
|
||||||
Inflate = "Inflate"
|
|
||||||
Levitate = "Levitate"
|
|
||||||
Melt = "Melt"
|
|
||||||
Peel = "Peel"
|
|
||||||
Poke = "Poke"
|
|
||||||
Squish = "Squish"
|
|
||||||
Ta_da = "Ta-da"
|
|
||||||
Tear = "Tear"
|
|
||||||
|
|
||||||
|
|
||||||
class PikaApiError(Exception):
|
|
||||||
"""Exception for Pika API errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_video_response(response: PikaVideoResponse) -> bool:
|
|
||||||
"""Check if the video response is valid."""
|
|
||||||
return hasattr(response, "url") and response.url is not None
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
|
|
||||||
"""Check if the initial response is valid."""
|
|
||||||
return hasattr(response, "video_id") and response.video_id is not None
|
|
||||||
|
|
||||||
|
|
||||||
async def poll_for_task_status(
|
|
||||||
task_id: str,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
) -> PikaGenerateResponse:
|
) -> comfy_io.NodeOutput:
|
||||||
polling_operation = PollingOperation(
|
task_id = (await initial_operation.execute()).video_id
|
||||||
|
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
|
||||||
poll_endpoint=ApiEndpoint(
|
poll_endpoint=ApiEndpoint(
|
||||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||||
method=HttpMethod.GET,
|
method=HttpMethod.GET,
|
||||||
request_model=EmptyRequest,
|
request_model=EmptyRequest,
|
||||||
response_model=PikaVideoResponse,
|
response_model=pika_defs.PikaVideoResponse,
|
||||||
),
|
),
|
||||||
completed_statuses=[
|
completed_statuses=["finished"],
|
||||||
"finished",
|
|
||||||
],
|
|
||||||
failed_statuses=["failed", "cancelled"],
|
failed_statuses=["failed", "cancelled"],
|
||||||
status_extractor=lambda response: (
|
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||||
response.status.value if response.status else None
|
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||||
),
|
|
||||||
progress_extractor=lambda response: (
|
|
||||||
response.progress if hasattr(response, "progress") else None
|
|
||||||
),
|
|
||||||
auth_kwargs=auth_kwargs,
|
auth_kwargs=auth_kwargs,
|
||||||
result_url_extractor=lambda response: (
|
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
|
||||||
response.url if hasattr(response, "url") else None
|
|
||||||
),
|
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
estimated_duration=60
|
estimated_duration=60,
|
||||||
)
|
max_poll_attempts=240,
|
||||||
return await polling_operation.execute()
|
).execute()
|
||||||
|
if not final_response.url:
|
||||||
|
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||||
async def execute_task(
|
|
||||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> tuple[VideoFromFile]:
|
|
||||||
"""Executes the initial operation then polls for the task status until it is completed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
initial_operation: The initial operation to execute.
|
|
||||||
auth_kwargs: The authentication token(s) to use for the API call.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing the video file as a VIDEO output.
|
|
||||||
"""
|
|
||||||
initial_response = await initial_operation.execute()
|
|
||||||
if not is_valid_initial_response(initial_response):
|
|
||||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise PikaApiError(error_msg)
|
raise Exception(error_msg)
|
||||||
|
video_url = final_response.url
|
||||||
task_id = initial_response.video_id
|
|
||||||
final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
|
|
||||||
if not is_valid_video_response(final_response):
|
|
||||||
error_msg = (
|
|
||||||
f"Pika task {task_id} succeeded but no video data found in response."
|
|
||||||
)
|
|
||||||
logging.error(error_msg)
|
|
||||||
raise PikaApiError(error_msg)
|
|
||||||
|
|
||||||
video_url = str(final_response.url)
|
|
||||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||||
|
return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
|
||||||
return (await download_url_to_video_output(video_url),)
|
|
||||||
|
|
||||||
|
|
||||||
def get_base_inputs_types() -> list[comfy_io.Input]:
|
def get_base_inputs_types() -> list[comfy_io.Input]:
|
||||||
@ -173,16 +80,12 @@ def get_base_inputs_types() -> list[comfy_io.Input]:
|
|||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
comfy_io.String.Input("prompt_text", multiline=True),
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||||
comfy_io.Combo.Input(
|
comfy_io.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||||
"resolution", options=PikaResolutionEnum, default=PikaResolutionEnum.field_1080p
|
comfy_io.Combo.Input("duration", options=[5, 10], default=5),
|
||||||
),
|
|
||||||
comfy_io.Combo.Input(
|
|
||||||
"duration", options=PikaDurationEnum, default=PikaDurationEnum.integer_5
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
class PikaImageToVideo(comfy_io.ComfyNode):
|
||||||
"""Pika 2.2 Image to Video Node."""
|
"""Pika 2.2 Image to Video Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -215,14 +118,9 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
|||||||
resolution: str,
|
resolution: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||||
# Prepare non-file data
|
|
||||||
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -237,8 +135,8 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_IMAGE_TO_VIDEO,
|
path=PATH_IMAGE_TO_VIDEO,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
|
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -248,7 +146,7 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
|||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
class PikaTextToVideoNode(comfy_io.ComfyNode):
|
||||||
"""Pika Text2Video v2.2 Node."""
|
"""Pika Text2Video v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -296,10 +194,10 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_TEXT_TO_VIDEO,
|
path=PATH_TEXT_TO_VIDEO,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
|
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=PikaBodyGenerate22T2vGenerate22T2vPost(
|
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -313,7 +211,7 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
|||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaScenesV2_2(comfy_io.ComfyNode):
|
class PikaScenes(comfy_io.ComfyNode):
|
||||||
"""PikaScenes v2.2 Node."""
|
"""PikaScenes v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -389,7 +287,6 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert all passed images to BytesIO
|
|
||||||
all_image_bytes_io = []
|
all_image_bytes_io = []
|
||||||
for image in [
|
for image in [
|
||||||
image_ingredient_1,
|
image_ingredient_1,
|
||||||
@ -399,16 +296,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
image_ingredient_5,
|
image_ingredient_5,
|
||||||
]:
|
]:
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||||
image_bytes_io.seek(0)
|
|
||||||
all_image_bytes_io.append(image_bytes_io)
|
|
||||||
|
|
||||||
pika_files = [
|
pika_files = [
|
||||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||||
]
|
]
|
||||||
|
|
||||||
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||||
ingredientsMode=ingredients_mode,
|
ingredientsMode=ingredients_mode,
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
@ -425,8 +320,8 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKASCENES,
|
path=PATH_PIKASCENES,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -477,22 +372,16 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
|||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert video to BytesIO
|
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {
|
pika_files = {
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
"image": ("image.png", image_bytes_io, "image/png"),
|
"image": ("image.png", image_bytes_io, "image/png"),
|
||||||
}
|
}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||||
# Prepare non-file data
|
|
||||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -505,8 +394,8 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKADDITIONS,
|
path=PATH_PIKADDITIONS,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -529,11 +418,25 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
|||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
|
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
|
||||||
comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."),
|
comfy_io.Image.Input(
|
||||||
comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"),
|
"image",
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
tooltip="The image used to replace the masked object in the video.",
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
optional=True,
|
||||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
),
|
||||||
|
comfy_io.Mask.Input(
|
||||||
|
"mask",
|
||||||
|
tooltip="Use the mask to define areas in the video to replace.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.String.Input("prompt_text", multiline=True, optional=True),
|
||||||
|
comfy_io.String.Input("negative_prompt", multiline=True, optional=True),
|
||||||
|
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"region_to_modify",
|
||||||
|
multiline=True,
|
||||||
|
optional=True,
|
||||||
|
tooltip="Plaintext description of the object / region to modify.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[comfy_io.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
@ -548,41 +451,29 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
image: torch.Tensor,
|
image: Optional[torch.Tensor] = None,
|
||||||
mask: torch.Tensor,
|
mask: Optional[torch.Tensor] = None,
|
||||||
prompt_text: str,
|
prompt_text: str = "",
|
||||||
negative_prompt: str,
|
negative_prompt: str = "",
|
||||||
seed: int,
|
seed: int = 0,
|
||||||
|
region_to_modify: str = "",
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert video to BytesIO
|
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
# Convert mask to binary mask with three channels
|
|
||||||
mask = torch.round(mask)
|
|
||||||
mask = mask.repeat(1, 3, 1, 1)
|
|
||||||
|
|
||||||
# Convert 3-channel binary mask to BytesIO
|
|
||||||
mask_bytes_io = BytesIO()
|
|
||||||
mask_bytes_io.write(mask.numpy().astype(np.uint8))
|
|
||||||
mask_bytes_io.seek(0)
|
|
||||||
|
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {
|
pika_files = {
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
"image": ("image.png", image_bytes_io, "image/png"),
|
|
||||||
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
|
|
||||||
}
|
}
|
||||||
|
if mask is not None:
|
||||||
|
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||||
|
if image is not None:
|
||||||
|
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||||
|
|
||||||
# Prepare non-file data
|
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||||
)
|
)
|
||||||
auth = {
|
auth = {
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
@ -590,10 +481,10 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
|||||||
}
|
}
|
||||||
initial_operation = SynchronousOperation(
|
initial_operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKADDITIONS,
|
path=PATH_PIKASWAPS,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -616,7 +507,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||||
comfy_io.Combo.Input(
|
comfy_io.Combo.Input(
|
||||||
"pikaffect", options=Pikaffect, default="Cake-ify"
|
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||||
),
|
),
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
comfy_io.String.Input("prompt_text", multiline=True),
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||||
@ -648,10 +539,10 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKAFFECTS,
|
path=PATH_PIKAFFECTS,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||||
pikaffect=pikaffect,
|
pikaffect=pikaffect,
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
@ -664,7 +555,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
class PikaStartEndFrameNode(comfy_io.ComfyNode):
|
||||||
"""PikaFrames v2.2 Node."""
|
"""PikaFrames v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -711,10 +602,10 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKAFRAMES,
|
path=PATH_PIKAFRAMES,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -732,13 +623,13 @@ class PikaApiNodesExtension(ComfyExtension):
|
|||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
PikaImageToVideoV2_2,
|
PikaImageToVideo,
|
||||||
PikaTextToVideoNodeV2_2,
|
PikaTextToVideoNode,
|
||||||
PikaScenesV2_2,
|
PikaScenes,
|
||||||
PikAdditionsNode,
|
PikAdditionsNode,
|
||||||
PikaSwapsNode,
|
PikaSwapsNode,
|
||||||
PikaffectsNode,
|
PikaffectsNode,
|
||||||
PikaStartEndFrameNode2_2,
|
PikaStartEndFrameNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,60 +1,80 @@
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class CLIPTextEncodeFlux:
|
|
||||||
|
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="CLIPTextEncodeFlux",
|
||||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="advanced/conditioning/flux",
|
||||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
|
||||||
def encode(self, clip, clip_l, t5xxl, guidance):
|
|
||||||
tokens = clip.tokenize(clip_l)
|
tokens = clip.tokenize(clip_l)
|
||||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||||
|
|
||||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
|
||||||
|
|
||||||
class FluxGuidance:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class FluxGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxGuidance",
|
||||||
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
|
category="advanced/conditioning/flux",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning, guidance) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
|
|
||||||
def append(self, conditioning, guidance):
|
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class FluxDisableGuidance:
|
class FluxDisableGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxDisableGuidance",
|
||||||
}}
|
category="advanced/conditioning/flux",
|
||||||
|
description="This node completely disables the guidance embed on Flux and Flux like models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
|
|
||||||
|
|
||||||
def append(self, conditioning):
|
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||||
@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FluxKontextImageScale:
|
class FluxKontextImageScale(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"image": ("IMAGE", ),
|
return io.Schema(
|
||||||
},
|
node_id="FluxKontextImageScale",
|
||||||
}
|
category="advanced/conditioning/flux",
|
||||||
|
description="This node resizes the image to one that is more optimal for flux kontext.",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
@classmethod
|
||||||
FUNCTION = "scale"
|
def execute(cls, image) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
|
|
||||||
|
|
||||||
def scale(self, image):
|
|
||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
height = image.shape[1]
|
height = image.shape[1]
|
||||||
aspect_ratio = width / height
|
aspect_ratio = width / height
|
||||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||||
return (image, )
|
return io.NodeOutput(image)
|
||||||
|
|
||||||
|
scale = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class FluxKontextMultiReferenceLatentMethod:
|
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxKontextMultiReferenceLatentMethod",
|
||||||
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
|
category="advanced/conditioning/flux",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.Combo.Input(
|
||||||
|
"reference_latents_method",
|
||||||
|
options=["offset", "index", "uxo/uno"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
|
|
||||||
def append(self, conditioning, reference_latents_method):
|
|
||||||
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
||||||
reference_latents_method = "uxo"
|
reference_latents_method = "uxo"
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
append = execute # TODO: remove
|
||||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
|
||||||
"FluxGuidance": FluxGuidance,
|
|
||||||
"FluxDisableGuidance": FluxDisableGuidance,
|
class FluxExtension(ComfyExtension):
|
||||||
"FluxKontextImageScale": FluxKontextImageScale,
|
@override
|
||||||
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
}
|
return [
|
||||||
|
CLIPTextEncodeFlux,
|
||||||
|
FluxGuidance,
|
||||||
|
FluxDisableGuidance,
|
||||||
|
FluxKontextImageScale,
|
||||||
|
FluxKontextMultiReferenceLatentMethod,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> FluxExtension:
|
||||||
|
return FluxExtension()
|
||||||
|
|||||||
@ -3,64 +3,83 @@ import comfy.sd
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import nodes
|
import nodes
|
||||||
import torch
|
import torch
|
||||||
import comfy_extras.nodes_slg
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
|
||||||
|
|
||||||
|
|
||||||
class TripleCLIPLoader:
|
class TripleCLIPLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
|
return io.Schema(
|
||||||
}}
|
node_id="TripleCLIPLoader",
|
||||||
RETURN_TYPES = ("CLIP",)
|
category="advanced/loaders",
|
||||||
FUNCTION = "load_clip"
|
description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Clip.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
@classmethod
|
||||||
|
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
|
||||||
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
|
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
|
||||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
load_clip = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class EmptySD3LatentImage:
|
class EmptySD3LatentImage(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.device = comfy.model_management.intermediate_device()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptySD3LatentImage",
|
||||||
|
category="latent/sd3",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
||||||
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
return io.NodeOutput({"samples":latent})
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/sd3"
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
|
||||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
|
||||||
return ({"samples":latent}, )
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="CLIPTextEncodeSD3",
|
||||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="advanced/conditioning",
|
||||||
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
io.Clip.Input("clip"),
|
||||||
"empty_padding": (["none", "empty_prompt"], )
|
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||||
}}
|
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning"
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
|
||||||
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
|
|
||||||
no_padding = empty_padding == "none"
|
no_padding = empty_padding == "none"
|
||||||
|
|
||||||
tokens = clip.tokenize(clip_g)
|
tokens = clip.tokenize(clip_g)
|
||||||
@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
|
|||||||
tokens["l"] += empty["l"]
|
tokens["l"] += empty["l"]
|
||||||
while len(tokens["l"]) > len(tokens["g"]):
|
while len(tokens["l"]) > len(tokens["g"]):
|
||||||
tokens["g"] += empty["g"]
|
tokens["g"] += empty["g"]
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||||
|
|
||||||
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
class ControlNetApplySD3(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING", ),
|
node_id="ControlNetApplySD3",
|
||||||
"control_net": ("CONTROL_NET", ),
|
display_name="Apply Controlnet with VAE",
|
||||||
"vae": ("VAE", ),
|
category="conditioning/controlnet",
|
||||||
"image": ("IMAGE", ),
|
inputs=[
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Conditioning.Input("positive"),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
io.Conditioning.Input("negative"),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
io.ControlNet.Input("control_net"),
|
||||||
}}
|
io.Vae.Input("vae"),
|
||||||
CATEGORY = "conditioning/controlnet"
|
io.Image.Input("image"),
|
||||||
DEPRECATED = True
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
|
||||||
|
if strength == 0:
|
||||||
|
return io.NodeOutput(positive, negative)
|
||||||
|
|
||||||
|
control_hint = image.movedim(-1, 1)
|
||||||
|
cnets = {}
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for conditioning in [positive, negative]:
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
d = t[1].copy()
|
||||||
|
|
||||||
|
prev_cnet = d.get('control', None)
|
||||||
|
if prev_cnet in cnets:
|
||||||
|
c_net = cnets[prev_cnet]
|
||||||
|
else:
|
||||||
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
|
||||||
|
vae=vae, extra_concat=[])
|
||||||
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
|
d['control'] = c_net
|
||||||
|
d['control_apply_to_uncond'] = False
|
||||||
|
n = [t[0], d]
|
||||||
|
c.append(n)
|
||||||
|
out.append(c)
|
||||||
|
return io.NodeOutput(out[0], out[1])
|
||||||
|
|
||||||
|
apply_controlnet = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
|
class SkipLayerGuidanceSD3(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
Experimental implementation by Dango233@StabilityAI.
|
Experimental implementation by Dango233@StabilityAI.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceSD3",
|
||||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
category="advanced/guidance",
|
||||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
inputs=[
|
||||||
}}
|
io.Model.Input("model"),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.String.Input("layers", default="7, 8, 9", multiline=False),
|
||||||
FUNCTION = "skip_guidance_sd3"
|
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||||
|
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/guidance"
|
@classmethod
|
||||||
|
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
|
||||||
|
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
||||||
|
|
||||||
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
|
skip_guidance_sd3 = execute # TODO: remove
|
||||||
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class SD3Extension(ComfyExtension):
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
@override
|
||||||
"EmptySD3LatentImage": EmptySD3LatentImage,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
return [
|
||||||
"ControlNetApplySD3": ControlNetApplySD3,
|
TripleCLIPLoader,
|
||||||
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
|
EmptySD3LatentImage,
|
||||||
}
|
CLIPTextEncodeSD3,
|
||||||
|
ControlNetApplySD3,
|
||||||
|
SkipLayerGuidanceSD3,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
# Sampling
|
async def comfy_entrypoint() -> SD3Extension:
|
||||||
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
return SD3Extension()
|
||||||
}
|
|
||||||
|
|||||||
@ -1,33 +1,40 @@
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import re
|
import re
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class SkipLayerGuidanceDiT:
|
class SkipLayerGuidanceDiT(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceDiT",
|
||||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
category="advanced/guidance",
|
||||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
|
||||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
is_experimental=True,
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
|
inputs=[
|
||||||
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.String.Input("double_layers", default="7, 8, 9"),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.String.Input("single_layers", default="7, 8, 9"),
|
||||||
FUNCTION = "skip_guidance"
|
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||||
EXPERIMENTAL = True
|
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
|
@classmethod
|
||||||
|
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
|
||||||
CATEGORY = "advanced/guidance"
|
|
||||||
|
|
||||||
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
|
|
||||||
# check if layer is comma separated integers
|
# check if layer is comma separated integers
|
||||||
def skip(args, extra_args):
|
def skip(args, extra_args):
|
||||||
return args
|
return args
|
||||||
@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
|
|||||||
single_layers = [int(i) for i in single_layers]
|
single_layers = [int(i) for i in single_layers]
|
||||||
|
|
||||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||||
return (model, )
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
model = args["model"]
|
model = args["model"]
|
||||||
@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
|
|||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class SkipLayerGuidanceDiTSimple:
|
skip_guidance = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class SkipLayerGuidanceDiTSimple(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
||||||
'''
|
'''
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceDiTSimple",
|
||||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
category="advanced/guidance",
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
is_experimental=True,
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Model.Input("model"),
|
||||||
FUNCTION = "skip_guidance"
|
io.String.Input("double_layers", default="7, 8, 9"),
|
||||||
EXPERIMENTAL = True
|
io.String.Input("single_layers", default="7, 8, 9"),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
|
@classmethod
|
||||||
|
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
|
||||||
CATEGORY = "advanced/guidance"
|
|
||||||
|
|
||||||
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
|
|
||||||
def skip(args, extra_args):
|
def skip(args, extra_args):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
|
|||||||
single_layers = [int(i) for i in single_layers]
|
single_layers = [int(i) for i in single_layers]
|
||||||
|
|
||||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||||
return (model, )
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
def calc_cond_batch_function(args):
|
def calc_cond_batch_function(args):
|
||||||
x = args["input"]
|
x = args["input"]
|
||||||
@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
|
|||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
||||||
|
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
skip_guidance = execute # TODO: remove
|
||||||
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
|
|
||||||
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
|
|
||||||
}
|
class SkipLayerGuidanceExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
SkipLayerGuidanceDiT,
|
||||||
|
SkipLayerGuidanceDiTSimple,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
|
||||||
|
return SkipLayerGuidanceExtension()
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from comfy import model_management
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from spandrel_extra_arches import EXTRA_REGISTRY
|
from spandrel_extra_arches import EXTRA_REGISTRY
|
||||||
@ -13,17 +15,23 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class UpscaleModelLoader:
|
class UpscaleModelLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
|
return io.Schema(
|
||||||
}}
|
node_id="UpscaleModelLoader",
|
||||||
RETURN_TYPES = ("UPSCALE_MODEL",)
|
display_name="Load Upscale Model",
|
||||||
FUNCTION = "load_model"
|
category="loaders",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.UpscaleModel.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
@classmethod
|
||||||
|
def execute(cls, model_name) -> io.NodeOutput:
|
||||||
def load_model(self, model_name):
|
|
||||||
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||||
@ -33,21 +41,29 @@ class UpscaleModelLoader:
|
|||||||
if not isinstance(out, ImageModelDescriptor):
|
if not isinstance(out, ImageModelDescriptor):
|
||||||
raise Exception("Upscale model must be a single-image model.")
|
raise Exception("Upscale model must be a single-image model.")
|
||||||
|
|
||||||
return (out, )
|
return io.NodeOutput(out)
|
||||||
|
|
||||||
|
load_model = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ImageUpscaleWithModel:
|
class ImageUpscaleWithModel(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "upscale_model": ("UPSCALE_MODEL",),
|
return io.Schema(
|
||||||
"image": ("IMAGE",),
|
node_id="ImageUpscaleWithModel",
|
||||||
}}
|
display_name="Upscale Image (using Model)",
|
||||||
RETURN_TYPES = ("IMAGE",)
|
category="image/upscaling",
|
||||||
FUNCTION = "upscale"
|
inputs=[
|
||||||
|
io.UpscaleModel.Input("upscale_model"),
|
||||||
|
io.Image.Input("image"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/upscaling"
|
@classmethod
|
||||||
|
def execute(cls, upscale_model, image) -> io.NodeOutput:
|
||||||
def upscale(self, upscale_model, image):
|
|
||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
|
|
||||||
memory_required = model_management.module_size(upscale_model.model)
|
memory_required = model_management.module_size(upscale_model.model)
|
||||||
@ -75,9 +91,19 @@ class ImageUpscaleWithModel:
|
|||||||
|
|
||||||
upscale_model.to("cpu")
|
upscale_model.to("cpu")
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||||
return (s,)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
upscale = execute # TODO: remove
|
||||||
"UpscaleModelLoader": UpscaleModelLoader,
|
|
||||||
"ImageUpscaleWithModel": ImageUpscaleWithModel
|
|
||||||
}
|
class UpscaleModelExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
UpscaleModelLoader,
|
||||||
|
ImageUpscaleWithModel,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> UpscaleModelExtension:
|
||||||
|
return UpscaleModelExtension()
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -2030,7 +2030,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"DiffControlNetLoader": "Load ControlNet Model (diff)",
|
"DiffControlNetLoader": "Load ControlNet Model (diff)",
|
||||||
"StyleModelLoader": "Load Style Model",
|
"StyleModelLoader": "Load Style Model",
|
||||||
"CLIPVisionLoader": "Load CLIP Vision",
|
"CLIPVisionLoader": "Load CLIP Vision",
|
||||||
"UpscaleModelLoader": "Load Upscale Model",
|
|
||||||
"UNETLoader": "Load Diffusion Model",
|
"UNETLoader": "Load Diffusion Model",
|
||||||
# Conditioning
|
# Conditioning
|
||||||
"CLIPVisionEncode": "CLIP Vision Encode",
|
"CLIPVisionEncode": "CLIP Vision Encode",
|
||||||
@ -2068,7 +2067,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LoadImageOutput": "Load Image (from Outputs)",
|
"LoadImageOutput": "Load Image (from Outputs)",
|
||||||
"ImageScale": "Upscale Image",
|
"ImageScale": "Upscale Image",
|
||||||
"ImageScaleBy": "Upscale Image By",
|
"ImageScaleBy": "Upscale Image By",
|
||||||
"ImageUpscaleWithModel": "Upscale Image (using Model)",
|
|
||||||
"ImageInvert": "Invert Image",
|
"ImageInvert": "Invert Image",
|
||||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||||
"ImageBatch": "Batch Images",
|
"ImageBatch": "Batch Images",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user