mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-22 00:12:34 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
ff7d9dd57d
@ -123,16 +123,30 @@ def move_weight_functions(m, device):
|
|||||||
return memory
|
return memory
|
||||||
|
|
||||||
class LowVramPatch:
|
class LowVramPatch:
|
||||||
def __init__(self, key, patches):
|
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.patches = patches
|
self.patches = patches
|
||||||
|
self.convert_func = convert_func
|
||||||
|
self.set_func = set_func
|
||||||
|
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
intermediate_dtype = weight.dtype
|
intermediate_dtype = weight.dtype
|
||||||
|
if self.convert_func is not None:
|
||||||
|
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
|
||||||
|
|
||||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||||
intermediate_dtype = torch.float32
|
intermediate_dtype = torch.float32
|
||||||
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
if self.set_func is None:
|
||||||
|
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
|
||||||
|
else:
|
||||||
|
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
|
||||||
|
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
if self.set_func is not None:
|
||||||
|
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
|
||||||
|
else:
|
||||||
|
return out
|
||||||
|
|
||||||
def get_key_weight(model, key):
|
def get_key_weight(model, key):
|
||||||
set_func = None
|
set_func = None
|
||||||
@ -657,13 +671,15 @@ class ModelPatcher:
|
|||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(weight_key)
|
self.patch_weight_to_device(weight_key)
|
||||||
else:
|
else:
|
||||||
m.weight_function = [LowVramPatch(weight_key, self.patches)]
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||||
|
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(bias_key)
|
self.patch_weight_to_device(bias_key)
|
||||||
else:
|
else:
|
||||||
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||||
|
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
@ -825,10 +841,12 @@ class ModelPatcher:
|
|||||||
module_mem += move_weight_functions(m, device_to)
|
module_mem += move_weight_functions(m, device_to)
|
||||||
if lowvram_possible:
|
if lowvram_possible:
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function.append(LowVramPatch(weight_key, self.patches))
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||||
|
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||||
|
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
|
|
||||||
|
|||||||
@ -416,8 +416,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
else:
|
else:
|
||||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||||
|
|
||||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||||
|
if return_weight:
|
||||||
|
return weight
|
||||||
if inplace_update:
|
if inplace_update:
|
||||||
self.weight.data.copy_(weight)
|
self.weight.data.copy_(weight)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
|
|||||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||||
from comfy_api.latest._io import _IO as io #noqa: F401
|
from . import _io as io
|
||||||
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
from . import _ui as ui
|
||||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
@ -114,6 +114,8 @@ if TYPE_CHECKING:
|
|||||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||||
|
|
||||||
|
comfy_io = io # create the new alias for io
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ComfyAPI",
|
"ComfyAPI",
|
||||||
"ComfyAPISync",
|
"ComfyAPISync",
|
||||||
@ -121,4 +123,7 @@ __all__ = [
|
|||||||
"InputImpl",
|
"InputImpl",
|
||||||
"Types",
|
"Types",
|
||||||
"ComfyExtension",
|
"ComfyExtension",
|
||||||
|
"io",
|
||||||
|
"comfy_io",
|
||||||
|
"ui",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, IO
|
||||||
import io
|
import io
|
||||||
import av
|
import av
|
||||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||||
@ -23,7 +23,7 @@ class VideoInput(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_to(
|
def save_to(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: Union[str, IO[bytes]],
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None
|
||||||
|
|||||||
@ -1582,78 +1582,78 @@ class _UIOutput(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class _IO:
|
__all__ = [
|
||||||
FolderType = FolderType
|
"FolderType",
|
||||||
UploadType = UploadType
|
"UploadType",
|
||||||
RemoteOptions = RemoteOptions
|
"RemoteOptions",
|
||||||
NumberDisplay = NumberDisplay
|
"NumberDisplay",
|
||||||
|
|
||||||
comfytype = staticmethod(comfytype)
|
"comfytype",
|
||||||
Custom = staticmethod(Custom)
|
"Custom",
|
||||||
Input = Input
|
"Input",
|
||||||
WidgetInput = WidgetInput
|
"WidgetInput",
|
||||||
Output = Output
|
"Output",
|
||||||
ComfyTypeI = ComfyTypeI
|
"ComfyTypeI",
|
||||||
ComfyTypeIO = ComfyTypeIO
|
"ComfyTypeIO",
|
||||||
#---------------------------------
|
|
||||||
# Supported Types
|
# Supported Types
|
||||||
Boolean = Boolean
|
"Boolean",
|
||||||
Int = Int
|
"Int",
|
||||||
Float = Float
|
"Float",
|
||||||
String = String
|
"String",
|
||||||
Combo = Combo
|
"Combo",
|
||||||
MultiCombo = MultiCombo
|
"MultiCombo",
|
||||||
Image = Image
|
"Image",
|
||||||
WanCameraEmbedding = WanCameraEmbedding
|
"WanCameraEmbedding",
|
||||||
Webcam = Webcam
|
"Webcam",
|
||||||
Mask = Mask
|
"Mask",
|
||||||
Latent = Latent
|
"Latent",
|
||||||
Conditioning = Conditioning
|
"Conditioning",
|
||||||
Sampler = Sampler
|
"Sampler",
|
||||||
Sigmas = Sigmas
|
"Sigmas",
|
||||||
Noise = Noise
|
"Noise",
|
||||||
Guider = Guider
|
"Guider",
|
||||||
Clip = Clip
|
"Clip",
|
||||||
ControlNet = ControlNet
|
"ControlNet",
|
||||||
Vae = Vae
|
"Vae",
|
||||||
Model = Model
|
"Model",
|
||||||
ClipVision = ClipVision
|
"ClipVision",
|
||||||
ClipVisionOutput = ClipVisionOutput
|
"ClipVisionOutput",
|
||||||
AudioEncoder = AudioEncoder
|
"AudioEncoder",
|
||||||
AudioEncoderOutput = AudioEncoderOutput
|
"AudioEncoderOutput",
|
||||||
StyleModel = StyleModel
|
"StyleModel",
|
||||||
Gligen = Gligen
|
"Gligen",
|
||||||
UpscaleModel = UpscaleModel
|
"UpscaleModel",
|
||||||
Audio = Audio
|
"Audio",
|
||||||
Video = Video
|
"Video",
|
||||||
SVG = SVG
|
"SVG",
|
||||||
LoraModel = LoraModel
|
"LoraModel",
|
||||||
LossMap = LossMap
|
"LossMap",
|
||||||
Voxel = Voxel
|
"Voxel",
|
||||||
Mesh = Mesh
|
"Mesh",
|
||||||
Hooks = Hooks
|
"Hooks",
|
||||||
HookKeyframes = HookKeyframes
|
"HookKeyframes",
|
||||||
TimestepsRange = TimestepsRange
|
"TimestepsRange",
|
||||||
LatentOperation = LatentOperation
|
"LatentOperation",
|
||||||
FlowControl = FlowControl
|
"FlowControl",
|
||||||
Accumulation = Accumulation
|
"Accumulation",
|
||||||
Load3DCamera = Load3DCamera
|
"Load3DCamera",
|
||||||
Load3D = Load3D
|
"Load3D",
|
||||||
Load3DAnimation = Load3DAnimation
|
"Load3DAnimation",
|
||||||
Photomaker = Photomaker
|
"Photomaker",
|
||||||
Point = Point
|
"Point",
|
||||||
FaceAnalysis = FaceAnalysis
|
"FaceAnalysis",
|
||||||
BBOX = BBOX
|
"BBOX",
|
||||||
SEGS = SEGS
|
"SEGS",
|
||||||
AnyType = AnyType
|
"AnyType",
|
||||||
MultiType = MultiType
|
"MultiType",
|
||||||
#---------------------------------
|
# Other classes
|
||||||
HiddenHolder = HiddenHolder
|
"HiddenHolder",
|
||||||
Hidden = Hidden
|
"Hidden",
|
||||||
NodeInfoV1 = NodeInfoV1
|
"NodeInfoV1",
|
||||||
NodeInfoV3 = NodeInfoV3
|
"NodeInfoV3",
|
||||||
Schema = Schema
|
"Schema",
|
||||||
ComfyNode = ComfyNode
|
"ComfyNode",
|
||||||
NodeOutput = NodeOutput
|
"NodeOutput",
|
||||||
add_to_dict_v1 = staticmethod(add_to_dict_v1)
|
"add_to_dict_v1",
|
||||||
add_to_dict_v3 = staticmethod(add_to_dict_v3)
|
"add_to_dict_v3",
|
||||||
|
]
|
||||||
|
|||||||
@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
|
|||||||
return {"text": (self.value,)}
|
return {"text": (self.value,)}
|
||||||
|
|
||||||
|
|
||||||
class _UI:
|
__all__ = [
|
||||||
SavedResult = SavedResult
|
"SavedResult",
|
||||||
SavedImages = SavedImages
|
"SavedImages",
|
||||||
SavedAudios = SavedAudios
|
"SavedAudios",
|
||||||
ImageSaveHelper = ImageSaveHelper
|
"ImageSaveHelper",
|
||||||
AudioSaveHelper = AudioSaveHelper
|
"AudioSaveHelper",
|
||||||
PreviewImage = PreviewImage
|
"PreviewImage",
|
||||||
PreviewMask = PreviewMask
|
"PreviewMask",
|
||||||
PreviewAudio = PreviewAudio
|
"PreviewAudio",
|
||||||
PreviewVideo = PreviewVideo
|
"PreviewVideo",
|
||||||
PreviewUI3D = PreviewUI3D
|
"PreviewUI3D",
|
||||||
PreviewText = PreviewText
|
"PreviewText",
|
||||||
|
]
|
||||||
|
|||||||
@ -269,7 +269,7 @@ def tensor_to_bytesio(
|
|||||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Named BytesIO object containing the image data.
|
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||||
"""
|
"""
|
||||||
if not mime_type:
|
if not mime_type:
|
||||||
mime_type = "image/png"
|
mime_type = "image/png"
|
||||||
@ -431,7 +431,7 @@ async def upload_video_to_comfyapi(
|
|||||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error getting video duration: {e}")
|
logging.error("Error getting video duration: %s", str(e))
|
||||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||||
|
|
||||||
upload_mime_type = f"video/{container.value.lower()}"
|
upload_mime_type = f"video/{container.value.lower()}"
|
||||||
|
|||||||
@ -98,7 +98,7 @@ import io
|
|||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
||||||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
|
from typing import Type, Optional, Any, TypeVar, Generic, Callable
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import json
|
import json
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
@ -175,7 +175,7 @@ class ApiClient:
|
|||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
retry_delay: float = 1.0,
|
retry_delay: float = 1.0,
|
||||||
retry_backoff_factor: float = 2.0,
|
retry_backoff_factor: float = 2.0,
|
||||||
retry_status_codes: Optional[Tuple[int, ...]] = None,
|
retry_status_codes: Optional[tuple[int, ...]] = None,
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
session: Optional[aiohttp.ClientSession] = None,
|
||||||
):
|
):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
@ -199,9 +199,9 @@ class ApiClient:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_json_payload_args(
|
def _create_json_payload_args(
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"json": data,
|
"json": data,
|
||||||
"headers": headers,
|
"headers": headers,
|
||||||
@ -209,11 +209,11 @@ class ApiClient:
|
|||||||
|
|
||||||
def _create_form_data_args(
|
def _create_form_data_args(
|
||||||
self,
|
self,
|
||||||
data: Dict[str, Any] | None,
|
data: dict[str, Any] | None,
|
||||||
files: Dict[str, Any] | None,
|
files: dict[str, Any] | None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
multipart_parser: Callable | None = None,
|
multipart_parser: Callable | None = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if headers and "Content-Type" in headers:
|
if headers and "Content-Type" in headers:
|
||||||
del headers["Content-Type"]
|
del headers["Content-Type"]
|
||||||
|
|
||||||
@ -254,9 +254,9 @@ class ApiClient:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_urlencoded_form_data_args(
|
def _create_urlencoded_form_data_args(
|
||||||
data: Dict[str, Any],
|
data: dict[str, Any],
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
headers = headers or {}
|
headers = headers or {}
|
||||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||||
return {
|
return {
|
||||||
@ -264,7 +264,7 @@ class ApiClient:
|
|||||||
"headers": headers,
|
"headers": headers,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_headers(self) -> Dict[str, str]:
|
def get_headers(self) -> dict[str, str]:
|
||||||
"""Get headers for API requests, including authentication if available"""
|
"""Get headers for API requests, including authentication if available"""
|
||||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
||||||
|
|
||||||
@ -275,7 +275,7 @@ class ApiClient:
|
|||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
|
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
|
||||||
"""
|
"""
|
||||||
Check connectivity to determine if network issues are local or server-related.
|
Check connectivity to determine if network issues are local or server-related.
|
||||||
|
|
||||||
@ -316,14 +316,14 @@ class ApiClient:
|
|||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
path: str,
|
path: str,
|
||||||
params: Optional[Dict[str, Any]] = None,
|
params: Optional[dict[str, Any]] = None,
|
||||||
data: Optional[Dict[str, Any]] = None,
|
data: Optional[dict[str, Any]] = None,
|
||||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[dict[str, str]] = None,
|
||||||
content_type: str = "application/json",
|
content_type: str = "application/json",
|
||||||
multipart_parser: Callable | None = None,
|
multipart_parser: Callable | None = None,
|
||||||
retry_count: int = 0, # Used internally for tracking retries
|
retry_count: int = 0, # Used internally for tracking retries
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Make an HTTP request to the API with automatic retries for transient errors.
|
Make an HTTP request to the API with automatic retries for transient errors.
|
||||||
|
|
||||||
@ -359,10 +359,10 @@ class ApiClient:
|
|||||||
if params:
|
if params:
|
||||||
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
||||||
|
|
||||||
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
logging.debug("[DEBUG] Request Headers: %s", request_headers)
|
||||||
logging.debug(f"[DEBUG] Files: {files}")
|
logging.debug("[DEBUG] Files: %s", files)
|
||||||
logging.debug(f"[DEBUG] Params: {params}")
|
logging.debug("[DEBUG] Params: %s", params)
|
||||||
logging.debug(f"[DEBUG] Data: {data}")
|
logging.debug("[DEBUG] Data: %s", data)
|
||||||
|
|
||||||
if content_type == "application/x-www-form-urlencoded":
|
if content_type == "application/x-www-form-urlencoded":
|
||||||
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
||||||
@ -485,7 +485,7 @@ class ApiClient:
|
|||||||
retry_delay: Initial delay between retries in seconds
|
retry_delay: Initial delay between retries in seconds
|
||||||
retry_backoff_factor: Multiplier for the delay after each retry
|
retry_backoff_factor: Multiplier for the delay after each retry
|
||||||
"""
|
"""
|
||||||
headers: Dict[str, str] = {}
|
headers: dict[str, str] = {}
|
||||||
skip_auto_headers: set[str] = set()
|
skip_auto_headers: set[str] = set()
|
||||||
if content_type:
|
if content_type:
|
||||||
headers["Content-Type"] = content_type
|
headers["Content-Type"] = content_type
|
||||||
@ -558,7 +558,7 @@ class ApiClient:
|
|||||||
*req_meta,
|
*req_meta,
|
||||||
retry_count: int,
|
retry_count: int,
|
||||||
response_content: dict | str = "",
|
response_content: dict | str = "",
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
status_code = exc.status
|
status_code = exc.status
|
||||||
if status_code == 401:
|
if status_code == 401:
|
||||||
user_friendly = "Unauthorized: Please login first to use this node."
|
user_friendly = "Unauthorized: Please login first to use this node."
|
||||||
@ -592,9 +592,9 @@ class ApiClient:
|
|||||||
error_message=f"HTTP Error {exc.status}",
|
error_message=f"HTTP Error {exc.status}",
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
|
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
|
||||||
if response_content:
|
if response_content:
|
||||||
logging.debug(f"[DEBUG] Response content: {response_content}")
|
logging.debug("[DEBUG] Response content: %s", response_content)
|
||||||
|
|
||||||
# Retry if eligible
|
# Retry if eligible
|
||||||
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
||||||
@ -659,7 +659,7 @@ class ApiEndpoint(Generic[T, R]):
|
|||||||
method: HttpMethod,
|
method: HttpMethod,
|
||||||
request_model: Type[T],
|
request_model: Type[T],
|
||||||
response_model: Type[R],
|
response_model: Type[R],
|
||||||
query_params: Optional[Dict[str, Any]] = None,
|
query_params: Optional[dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""Initialize an API endpoint definition.
|
"""Initialize an API endpoint definition.
|
||||||
|
|
||||||
@ -684,11 +684,11 @@ class SynchronousOperation(Generic[T, R]):
|
|||||||
self,
|
self,
|
||||||
endpoint: ApiEndpoint[T, R],
|
endpoint: ApiEndpoint[T, R],
|
||||||
request: T,
|
request: T,
|
||||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
auth_token: Optional[str] = None,
|
auth_token: Optional[str] = None,
|
||||||
comfy_api_key: Optional[str] = None,
|
comfy_api_key: Optional[str] = None,
|
||||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
timeout: float = 7200.0,
|
timeout: float = 7200.0,
|
||||||
verify_ssl: bool = True,
|
verify_ssl: bool = True,
|
||||||
content_type: str = "application/json",
|
content_type: str = "application/json",
|
||||||
@ -729,7 +729,7 @@ class SynchronousOperation(Generic[T, R]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
request_dict: Optional[Dict[str, Any]]
|
request_dict: Optional[dict[str, Any]]
|
||||||
if isinstance(self.request, EmptyRequest):
|
if isinstance(self.request, EmptyRequest):
|
||||||
request_dict = None
|
request_dict = None
|
||||||
else:
|
else:
|
||||||
@ -738,11 +738,9 @@ class SynchronousOperation(Generic[T, R]):
|
|||||||
if isinstance(v, Enum):
|
if isinstance(v, Enum):
|
||||||
request_dict[k] = v.value
|
request_dict[k] = v.value
|
||||||
|
|
||||||
logging.debug(
|
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
|
||||||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
|
||||||
)
|
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
|
||||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
|
||||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
|
||||||
|
|
||||||
response_json = await client.request(
|
response_json = await client.request(
|
||||||
self.endpoint.method.value,
|
self.endpoint.method.value,
|
||||||
@ -757,11 +755,11 @@ class SynchronousOperation(Generic[T, R]):
|
|||||||
logging.debug("=" * 50)
|
logging.debug("=" * 50)
|
||||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
||||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
||||||
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
|
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
|
||||||
logging.debug("=" * 50)
|
logging.debug("=" * 50)
|
||||||
|
|
||||||
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
||||||
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
|
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
|
||||||
return parsed_response
|
return parsed_response
|
||||||
finally:
|
finally:
|
||||||
if owns_client:
|
if owns_client:
|
||||||
@ -784,14 +782,14 @@ class PollingOperation(Generic[T, R]):
|
|||||||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
||||||
completed_statuses: list[str],
|
completed_statuses: list[str],
|
||||||
failed_statuses: list[str],
|
failed_statuses: list[str],
|
||||||
status_extractor: Callable[[R], str],
|
status_extractor: Callable[[R], Optional[str]],
|
||||||
progress_extractor: Callable[[R], float] | None = None,
|
progress_extractor: Callable[[R], Optional[float]] | None = None,
|
||||||
result_url_extractor: Callable[[R], str] | None = None,
|
result_url_extractor: Callable[[R], Optional[str]] | None = None,
|
||||||
request: Optional[T] = None,
|
request: Optional[T] = None,
|
||||||
api_base: str | None = None,
|
api_base: str | None = None,
|
||||||
auth_token: Optional[str] = None,
|
auth_token: Optional[str] = None,
|
||||||
comfy_api_key: Optional[str] = None,
|
comfy_api_key: Optional[str] = None,
|
||||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
poll_interval: float = 5.0,
|
poll_interval: float = 5.0,
|
||||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
||||||
max_retries: int = 3, # Max retries per individual API call
|
max_retries: int = 3, # Max retries per individual API call
|
||||||
@ -877,7 +875,7 @@ class PollingOperation(Generic[T, R]):
|
|||||||
status = TaskStatus.PENDING
|
status = TaskStatus.PENDING
|
||||||
for poll_count in range(1, self.max_poll_attempts + 1):
|
for poll_count in range(1, self.max_poll_attempts + 1):
|
||||||
try:
|
try:
|
||||||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
|
||||||
|
|
||||||
request_dict = (
|
request_dict = (
|
||||||
None if self.request is None else self.request.model_dump(exclude_none=True)
|
None if self.request is None else self.request.model_dump(exclude_none=True)
|
||||||
@ -885,10 +883,13 @@ class PollingOperation(Generic[T, R]):
|
|||||||
|
|
||||||
if poll_count == 1:
|
if poll_count == 1:
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
|
"[DEBUG] Poll Request: %s %s",
|
||||||
|
self.poll_endpoint.method.value,
|
||||||
|
self.poll_endpoint.path,
|
||||||
)
|
)
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
|
"[DEBUG] Poll Request Data: %s",
|
||||||
|
json.dumps(request_dict, indent=2) if request_dict else "None",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query task status
|
# Query task status
|
||||||
@ -903,7 +904,7 @@ class PollingOperation(Generic[T, R]):
|
|||||||
|
|
||||||
# Check if task is complete
|
# Check if task is complete
|
||||||
status = self._check_task_status(response_obj)
|
status = self._check_task_status(response_obj)
|
||||||
logging.debug(f"[DEBUG] Task Status: {status}")
|
logging.debug("[DEBUG] Task Status: %s", status)
|
||||||
|
|
||||||
# If progress extractor is provided, extract progress
|
# If progress extractor is provided, extract progress
|
||||||
if self.progress_extractor:
|
if self.progress_extractor:
|
||||||
@ -917,7 +918,7 @@ class PollingOperation(Generic[T, R]):
|
|||||||
result_url = self.result_url_extractor(response_obj)
|
result_url = self.result_url_extractor(response_obj)
|
||||||
if result_url:
|
if result_url:
|
||||||
message = f"Result URL: {result_url}"
|
message = f"Result URL: {result_url}"
|
||||||
logging.debug(f"[DEBUG] {message}")
|
logging.debug("[DEBUG] %s", message)
|
||||||
self._display_text_on_node(message)
|
self._display_text_on_node(message)
|
||||||
self.final_response = response_obj
|
self.final_response = response_obj
|
||||||
if self.progress_extractor:
|
if self.progress_extractor:
|
||||||
@ -925,7 +926,7 @@ class PollingOperation(Generic[T, R]):
|
|||||||
return self.final_response
|
return self.final_response
|
||||||
if status == TaskStatus.FAILED:
|
if status == TaskStatus.FAILED:
|
||||||
message = f"Task failed: {json.dumps(resp)}"
|
message = f"Task failed: {json.dumps(resp)}"
|
||||||
logging.error(f"[DEBUG] {message}")
|
logging.error("[DEBUG] %s", message)
|
||||||
raise Exception(message)
|
raise Exception(message)
|
||||||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
||||||
# Task pending – wait
|
# Task pending – wait
|
||||||
@ -939,7 +940,12 @@ class PollingOperation(Generic[T, R]):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
|
logging.warning(
|
||||||
|
"Network error (%s/%s): %s",
|
||||||
|
consecutive_errors,
|
||||||
|
max_consecutive_errors,
|
||||||
|
str(e),
|
||||||
|
)
|
||||||
await asyncio.sleep(self.poll_interval)
|
await asyncio.sleep(self.poll_interval)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# For other errors, increment count and potentially abort
|
# For other errors, increment count and potentially abort
|
||||||
@ -949,10 +955,13 @@ class PollingOperation(Generic[T, R]):
|
|||||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
logging.error("[DEBUG] Polling error: %s", str(e))
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
|
||||||
f"Will retry in {self.poll_interval} seconds."
|
poll_count,
|
||||||
|
self.max_poll_attempts,
|
||||||
|
str(e),
|
||||||
|
self.poll_interval,
|
||||||
)
|
)
|
||||||
await asyncio.sleep(self.poll_interval)
|
await asyncio.sleep(self.poll_interval)
|
||||||
|
|
||||||
|
|||||||
100
comfy_api_nodes/apis/pika_defs.py
Normal file
100
comfy_api_nodes/apis/pika_defs.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Pikaffect(str, Enum):
|
||||||
|
Cake_ify = "Cake-ify"
|
||||||
|
Crumble = "Crumble"
|
||||||
|
Crush = "Crush"
|
||||||
|
Decapitate = "Decapitate"
|
||||||
|
Deflate = "Deflate"
|
||||||
|
Dissolve = "Dissolve"
|
||||||
|
Explode = "Explode"
|
||||||
|
Eye_pop = "Eye-pop"
|
||||||
|
Inflate = "Inflate"
|
||||||
|
Levitate = "Levitate"
|
||||||
|
Melt = "Melt"
|
||||||
|
Peel = "Peel"
|
||||||
|
Poke = "Poke"
|
||||||
|
Squish = "Squish"
|
||||||
|
Ta_da = "Ta-da"
|
||||||
|
Tear = "Tear"
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
|
||||||
|
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
|
||||||
|
duration: Optional[int] = Field(5)
|
||||||
|
ingredientsMode: str = Field(...)
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
resolution: Optional[str] = Field('1080p')
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaGenerateResponse(BaseModel):
|
||||||
|
video_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
|
||||||
|
duration: Optional[int] = 5
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
|
||||||
|
duration: Optional[int] = Field(None, ge=5, le=10)
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: str = Field(...)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
|
||||||
|
aspectRatio: Optional[float] = Field(
|
||||||
|
1.7777777777777777,
|
||||||
|
description='Aspect ratio (width / height)',
|
||||||
|
ge=0.4,
|
||||||
|
le=2.5,
|
||||||
|
)
|
||||||
|
duration: Optional[int] = 5
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: str = Field(...)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
pikaffect: Optional[str] = None
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
modifyRegionRoi: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaStatusEnum(str, Enum):
|
||||||
|
queued = "queued"
|
||||||
|
started = "started"
|
||||||
|
finished = "finished"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class PikaVideoResponse(BaseModel):
|
||||||
|
id: str = Field(...)
|
||||||
|
progress: Optional[int] = Field(None)
|
||||||
|
status: PikaStatusEnum
|
||||||
|
url: Optional[str] = Field(None)
|
||||||
@ -21,7 +21,7 @@ def get_log_directory():
|
|||||||
try:
|
try:
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating API log directory {log_dir}: {e}")
|
logger.error("Error creating API log directory %s: %s", log_dir, str(e))
|
||||||
# Fallback to base temp directory if sub-directory creation fails
|
# Fallback to base temp directory if sub-directory creation fails
|
||||||
return base_temp_dir
|
return base_temp_dir
|
||||||
return log_dir
|
return log_dir
|
||||||
@ -122,9 +122,9 @@ def log_request_response(
|
|||||||
try:
|
try:
|
||||||
with open(filepath, "w", encoding="utf-8") as f:
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
f.write("\n".join(log_content))
|
f.write("\n".join(log_content))
|
||||||
logger.debug(f"API log saved to: {filepath}")
|
logger.debug("API log saved to: %s", filepath)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error writing API log to {filepath}: {e}")
|
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
@ -296,7 +296,7 @@ def validate_video_result_response(response) -> None:
|
|||||||
"""Validates that the Kling task result contains a video."""
|
"""Validates that the Kling task result contains a video."""
|
||||||
if not is_valid_video_response(response):
|
if not is_valid_video_response(response):
|
||||||
error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response."
|
error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response."
|
||||||
logging.error(f"Error: {error_msg}.\nResponse: {response}")
|
logging.error("Error: %s.\nResponse: %s", error_msg, response)
|
||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
|
|
||||||
|
|
||||||
@ -304,7 +304,7 @@ def validate_image_result_response(response) -> None:
|
|||||||
"""Validates that the Kling task result contains an image."""
|
"""Validates that the Kling task result contains an image."""
|
||||||
if not is_valid_image_response(response):
|
if not is_valid_image_response(response):
|
||||||
error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response."
|
error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response."
|
||||||
logging.error(f"Error: {error_msg}.\nResponse: {response}")
|
logging.error("Error: %s.\nResponse: %s", error_msg, response)
|
||||||
raise Exception(error_msg)
|
raise Exception(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -500,7 +500,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||||
)
|
)
|
||||||
logging.info(f"Generated video URL: {file_url}")
|
logging.info("Generated video URL: %s", file_url)
|
||||||
if cls.hidden.unique_id:
|
if cls.hidden.unique_id:
|
||||||
if hasattr(file_result.file, "backup_download_url"):
|
if hasattr(file_result.file, "backup_download_url"):
|
||||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||||
|
|||||||
@ -237,7 +237,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
audio_stream = None
|
audio_stream = None
|
||||||
|
|
||||||
for stream in input_container.streams:
|
for stream in input_container.streams:
|
||||||
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
|
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||||
if isinstance(stream, av.VideoStream):
|
if isinstance(stream, av.VideoStream):
|
||||||
# Create output video stream with same parameters
|
# Create output video stream with same parameters
|
||||||
video_stream = output_container.add_stream(
|
video_stream = output_container.add_stream(
|
||||||
@ -247,7 +247,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
video_stream.height = stream.height
|
video_stream.height = stream.height
|
||||||
video_stream.pix_fmt = "yuv420p"
|
video_stream.pix_fmt = "yuv420p"
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
|
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
|
||||||
)
|
)
|
||||||
elif isinstance(stream, av.AudioStream):
|
elif isinstance(stream, av.AudioStream):
|
||||||
# Create output audio stream with same parameters
|
# Create output audio stream with same parameters
|
||||||
@ -256,9 +256,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
)
|
)
|
||||||
audio_stream.sample_rate = stream.sample_rate
|
audio_stream.sample_rate = stream.sample_rate
|
||||||
audio_stream.layout = stream.layout
|
audio_stream.layout = stream.layout
|
||||||
logging.info(
|
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||||
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate target frame count that's divisible by 16
|
# Calculate target frame count that's divisible by 16
|
||||||
fps = input_container.streams.video[0].average_rate
|
fps = input_container.streams.video[0].average_rate
|
||||||
@ -288,9 +286,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
for packet in video_stream.encode():
|
for packet in video_stream.encode():
|
||||||
output_container.mux(packet)
|
output_container.mux(packet)
|
||||||
|
|
||||||
logging.info(
|
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||||
f"Encoded {frame_count} video frames (target: {target_frames})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decode and re-encode audio frames
|
# Decode and re-encode audio frames
|
||||||
if audio_stream:
|
if audio_stream:
|
||||||
@ -308,7 +304,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
for packet in audio_stream.encode():
|
for packet in audio_stream.encode():
|
||||||
output_container.mux(packet)
|
output_container.mux(packet)
|
||||||
|
|
||||||
logging.info(f"Encoded {audio_frame_count} audio frames")
|
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||||
|
|
||||||
# Close containers
|
# Close containers
|
||||||
output_container.close()
|
output_container.close()
|
||||||
|
|||||||
@ -8,30 +8,17 @@ from __future__ import annotations
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, TypeVar
|
from typing import Optional, TypeVar
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import ComfyExtension, comfy_io
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||||
from comfy_api_nodes.apinode_utils import (
|
from comfy_api_nodes.apinode_utils import (
|
||||||
download_url_to_video_output,
|
download_url_to_video_output,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import pika_defs
|
||||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
|
||||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
|
||||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
|
||||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
|
||||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
|
||||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
|
||||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
|
||||||
PikaGenerateResponse,
|
|
||||||
PikaVideoResponse,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
EmptyRequest,
|
EmptyRequest,
|
||||||
@ -55,116 +42,36 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
|||||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||||
|
|
||||||
|
|
||||||
class PikaDurationEnum(int, Enum):
|
async def execute_task(
|
||||||
integer_5 = 5
|
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
|
||||||
integer_10 = 10
|
|
||||||
|
|
||||||
|
|
||||||
class PikaResolutionEnum(str, Enum):
|
|
||||||
field_1080p = "1080p"
|
|
||||||
field_720p = "720p"
|
|
||||||
|
|
||||||
|
|
||||||
class Pikaffect(str, Enum):
|
|
||||||
Cake_ify = "Cake-ify"
|
|
||||||
Crumble = "Crumble"
|
|
||||||
Crush = "Crush"
|
|
||||||
Decapitate = "Decapitate"
|
|
||||||
Deflate = "Deflate"
|
|
||||||
Dissolve = "Dissolve"
|
|
||||||
Explode = "Explode"
|
|
||||||
Eye_pop = "Eye-pop"
|
|
||||||
Inflate = "Inflate"
|
|
||||||
Levitate = "Levitate"
|
|
||||||
Melt = "Melt"
|
|
||||||
Peel = "Peel"
|
|
||||||
Poke = "Poke"
|
|
||||||
Squish = "Squish"
|
|
||||||
Ta_da = "Ta-da"
|
|
||||||
Tear = "Tear"
|
|
||||||
|
|
||||||
|
|
||||||
class PikaApiError(Exception):
|
|
||||||
"""Exception for Pika API errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_video_response(response: PikaVideoResponse) -> bool:
|
|
||||||
"""Check if the video response is valid."""
|
|
||||||
return hasattr(response, "url") and response.url is not None
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
|
|
||||||
"""Check if the initial response is valid."""
|
|
||||||
return hasattr(response, "video_id") and response.video_id is not None
|
|
||||||
|
|
||||||
|
|
||||||
async def poll_for_task_status(
|
|
||||||
task_id: str,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
) -> PikaGenerateResponse:
|
) -> comfy_io.NodeOutput:
|
||||||
polling_operation = PollingOperation(
|
task_id = (await initial_operation.execute()).video_id
|
||||||
|
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
|
||||||
poll_endpoint=ApiEndpoint(
|
poll_endpoint=ApiEndpoint(
|
||||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||||
method=HttpMethod.GET,
|
method=HttpMethod.GET,
|
||||||
request_model=EmptyRequest,
|
request_model=EmptyRequest,
|
||||||
response_model=PikaVideoResponse,
|
response_model=pika_defs.PikaVideoResponse,
|
||||||
),
|
),
|
||||||
completed_statuses=[
|
completed_statuses=["finished"],
|
||||||
"finished",
|
|
||||||
],
|
|
||||||
failed_statuses=["failed", "cancelled"],
|
failed_statuses=["failed", "cancelled"],
|
||||||
status_extractor=lambda response: (
|
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||||
response.status.value if response.status else None
|
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||||
),
|
|
||||||
progress_extractor=lambda response: (
|
|
||||||
response.progress if hasattr(response, "progress") else None
|
|
||||||
),
|
|
||||||
auth_kwargs=auth_kwargs,
|
auth_kwargs=auth_kwargs,
|
||||||
result_url_extractor=lambda response: (
|
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
|
||||||
response.url if hasattr(response, "url") else None
|
|
||||||
),
|
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
estimated_duration=60
|
estimated_duration=60,
|
||||||
)
|
max_poll_attempts=240,
|
||||||
return await polling_operation.execute()
|
).execute()
|
||||||
|
if not final_response.url:
|
||||||
|
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||||
async def execute_task(
|
|
||||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> tuple[VideoFromFile]:
|
|
||||||
"""Executes the initial operation then polls for the task status until it is completed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
initial_operation: The initial operation to execute.
|
|
||||||
auth_kwargs: The authentication token(s) to use for the API call.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A tuple containing the video file as a VIDEO output.
|
|
||||||
"""
|
|
||||||
initial_response = await initial_operation.execute()
|
|
||||||
if not is_valid_initial_response(initial_response):
|
|
||||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise PikaApiError(error_msg)
|
raise Exception(error_msg)
|
||||||
|
video_url = final_response.url
|
||||||
task_id = initial_response.video_id
|
|
||||||
final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
|
|
||||||
if not is_valid_video_response(final_response):
|
|
||||||
error_msg = (
|
|
||||||
f"Pika task {task_id} succeeded but no video data found in response."
|
|
||||||
)
|
|
||||||
logging.error(error_msg)
|
|
||||||
raise PikaApiError(error_msg)
|
|
||||||
|
|
||||||
video_url = str(final_response.url)
|
|
||||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||||
|
return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
|
||||||
return (await download_url_to_video_output(video_url),)
|
|
||||||
|
|
||||||
|
|
||||||
def get_base_inputs_types() -> list[comfy_io.Input]:
|
def get_base_inputs_types() -> list[comfy_io.Input]:
|
||||||
@ -173,16 +80,12 @@ def get_base_inputs_types() -> list[comfy_io.Input]:
|
|||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
comfy_io.String.Input("prompt_text", multiline=True),
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||||
comfy_io.Combo.Input(
|
comfy_io.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||||
"resolution", options=PikaResolutionEnum, default=PikaResolutionEnum.field_1080p
|
comfy_io.Combo.Input("duration", options=[5, 10], default=5),
|
||||||
),
|
|
||||||
comfy_io.Combo.Input(
|
|
||||||
"duration", options=PikaDurationEnum, default=PikaDurationEnum.integer_5
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
class PikaImageToVideo(comfy_io.ComfyNode):
|
||||||
"""Pika 2.2 Image to Video Node."""
|
"""Pika 2.2 Image to Video Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -215,14 +118,9 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
|||||||
resolution: str,
|
resolution: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||||
# Prepare non-file data
|
|
||||||
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -237,8 +135,8 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_IMAGE_TO_VIDEO,
|
path=PATH_IMAGE_TO_VIDEO,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
|
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -248,7 +146,7 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
|||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
class PikaTextToVideoNode(comfy_io.ComfyNode):
|
||||||
"""Pika Text2Video v2.2 Node."""
|
"""Pika Text2Video v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -296,10 +194,10 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_TEXT_TO_VIDEO,
|
path=PATH_TEXT_TO_VIDEO,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
|
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=PikaBodyGenerate22T2vGenerate22T2vPost(
|
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -313,7 +211,7 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
|||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaScenesV2_2(comfy_io.ComfyNode):
|
class PikaScenes(comfy_io.ComfyNode):
|
||||||
"""PikaScenes v2.2 Node."""
|
"""PikaScenes v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -389,7 +287,6 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert all passed images to BytesIO
|
|
||||||
all_image_bytes_io = []
|
all_image_bytes_io = []
|
||||||
for image in [
|
for image in [
|
||||||
image_ingredient_1,
|
image_ingredient_1,
|
||||||
@ -399,16 +296,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
image_ingredient_5,
|
image_ingredient_5,
|
||||||
]:
|
]:
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||||
image_bytes_io.seek(0)
|
|
||||||
all_image_bytes_io.append(image_bytes_io)
|
|
||||||
|
|
||||||
pika_files = [
|
pika_files = [
|
||||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||||
]
|
]
|
||||||
|
|
||||||
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||||
ingredientsMode=ingredients_mode,
|
ingredientsMode=ingredients_mode,
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
@ -425,8 +320,8 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKASCENES,
|
path=PATH_PIKASCENES,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -477,22 +372,16 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
|||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert video to BytesIO
|
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {
|
pika_files = {
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
"image": ("image.png", image_bytes_io, "image/png"),
|
"image": ("image.png", image_bytes_io, "image/png"),
|
||||||
}
|
}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||||
# Prepare non-file data
|
|
||||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -505,8 +394,8 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKADDITIONS,
|
path=PATH_PIKADDITIONS,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -529,11 +418,25 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
|||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
|
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
|
||||||
comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."),
|
comfy_io.Image.Input(
|
||||||
comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"),
|
"image",
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
tooltip="The image used to replace the masked object in the video.",
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
optional=True,
|
||||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
),
|
||||||
|
comfy_io.Mask.Input(
|
||||||
|
"mask",
|
||||||
|
tooltip="Use the mask to define areas in the video to replace.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.String.Input("prompt_text", multiline=True, optional=True),
|
||||||
|
comfy_io.String.Input("negative_prompt", multiline=True, optional=True),
|
||||||
|
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"region_to_modify",
|
||||||
|
multiline=True,
|
||||||
|
optional=True,
|
||||||
|
tooltip="Plaintext description of the object / region to modify.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[comfy_io.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
@ -548,41 +451,29 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
image: torch.Tensor,
|
image: Optional[torch.Tensor] = None,
|
||||||
mask: torch.Tensor,
|
mask: Optional[torch.Tensor] = None,
|
||||||
prompt_text: str,
|
prompt_text: str = "",
|
||||||
negative_prompt: str,
|
negative_prompt: str = "",
|
||||||
seed: int,
|
seed: int = 0,
|
||||||
|
region_to_modify: str = "",
|
||||||
) -> comfy_io.NodeOutput:
|
) -> comfy_io.NodeOutput:
|
||||||
# Convert video to BytesIO
|
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
# Convert mask to binary mask with three channels
|
|
||||||
mask = torch.round(mask)
|
|
||||||
mask = mask.repeat(1, 3, 1, 1)
|
|
||||||
|
|
||||||
# Convert 3-channel binary mask to BytesIO
|
|
||||||
mask_bytes_io = BytesIO()
|
|
||||||
mask_bytes_io.write(mask.numpy().astype(np.uint8))
|
|
||||||
mask_bytes_io.seek(0)
|
|
||||||
|
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {
|
pika_files = {
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
"image": ("image.png", image_bytes_io, "image/png"),
|
|
||||||
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
|
|
||||||
}
|
}
|
||||||
|
if mask is not None:
|
||||||
|
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||||
|
if image is not None:
|
||||||
|
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||||
|
|
||||||
# Prepare non-file data
|
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||||
)
|
)
|
||||||
auth = {
|
auth = {
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
@ -590,10 +481,10 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
|||||||
}
|
}
|
||||||
initial_operation = SynchronousOperation(
|
initial_operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKADDITIONS,
|
path=PATH_PIKASWAPS,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=pika_request_data,
|
request=pika_request_data,
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
@ -616,7 +507,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||||
comfy_io.Combo.Input(
|
comfy_io.Combo.Input(
|
||||||
"pikaffect", options=Pikaffect, default="Cake-ify"
|
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||||
),
|
),
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
comfy_io.String.Input("prompt_text", multiline=True),
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
comfy_io.String.Input("negative_prompt", multiline=True),
|
||||||
@ -648,10 +539,10 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKAFFECTS,
|
path=PATH_PIKAFFECTS,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||||
pikaffect=pikaffect,
|
pikaffect=pikaffect,
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
@ -664,7 +555,7 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
class PikaStartEndFrameNode(comfy_io.ComfyNode):
|
||||||
"""PikaFrames v2.2 Node."""
|
"""PikaFrames v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -711,10 +602,10 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
|||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=PATH_PIKAFRAMES,
|
path=PATH_PIKAFRAMES,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||||
response_model=PikaGenerateResponse,
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
),
|
),
|
||||||
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -732,13 +623,13 @@ class PikaApiNodesExtension(ComfyExtension):
|
|||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
PikaImageToVideoV2_2,
|
PikaImageToVideo,
|
||||||
PikaTextToVideoNodeV2_2,
|
PikaTextToVideoNode,
|
||||||
PikaScenesV2_2,
|
PikaScenes,
|
||||||
PikAdditionsNode,
|
PikAdditionsNode,
|
||||||
PikaSwapsNode,
|
PikaSwapsNode,
|
||||||
PikaffectsNode,
|
PikaffectsNode,
|
||||||
PikaStartEndFrameNode2_2,
|
PikaStartEndFrameNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -172,16 +172,16 @@ async def create_generate_task(
|
|||||||
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
|
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
|
||||||
subscription_key = response.jobs.subscription_key
|
subscription_key = response.jobs.subscription_key
|
||||||
task_uuid = response.uuid
|
task_uuid = response.uuid
|
||||||
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
|
||||||
return task_uuid, subscription_key
|
return task_uuid, subscription_key
|
||||||
|
|
||||||
|
|
||||||
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||||
all_done = all(job.status == JobStatus.Done for job in response.jobs)
|
all_done = all(job.status == JobStatus.Done for job in response.jobs)
|
||||||
status_list = [str(job.status) for job in response.jobs]
|
status_list = [str(job.status) for job in response.jobs]
|
||||||
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
|
logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
|
||||||
if any(job.status == JobStatus.Failed for job in response.jobs):
|
if any(job.status == JobStatus.Failed for job in response.jobs):
|
||||||
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
|
logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
|
||||||
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
|
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
|
||||||
if all_done:
|
if all_done:
|
||||||
return "DONE"
|
return "DONE"
|
||||||
@ -235,7 +235,7 @@ async def download_files(url_list, task_uuid):
|
|||||||
file_path = os.path.join(save_path, file_name)
|
file_path = os.path.join(save_path, file_name)
|
||||||
if file_path.endswith(".glb"):
|
if file_path.endswith(".glb"):
|
||||||
model_file_path = file_path
|
model_file_path = file_path
|
||||||
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
|
||||||
max_retries = 5
|
max_retries = 5
|
||||||
for attempt in range(max_retries):
|
for attempt in range(max_retries):
|
||||||
try:
|
try:
|
||||||
@ -246,7 +246,7 @@ async def download_files(url_list, task_uuid):
|
|||||||
f.write(chunk)
|
f.write(chunk)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
logging.info("Retrying...")
|
logging.info("Retrying...")
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|||||||
@ -215,7 +215,7 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
initial_response = await initial_operation.execute()
|
initial_response = await initial_operation.execute()
|
||||||
operation_name = initial_response.name
|
operation_name = initial_response.name
|
||||||
|
|
||||||
logging.info(f"Veo generation started with operation name: {operation_name}")
|
logging.info("Veo generation started with operation name: %s", operation_name)
|
||||||
|
|
||||||
# Define status extractor function
|
# Define status extractor function
|
||||||
def status_extractor(response):
|
def status_extractor(response):
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def resize_mask(mask, shape):
|
def resize_mask(mask, shape):
|
||||||
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
||||||
@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
|
|||||||
return out_image, out_alpha
|
return out_image, out_alpha
|
||||||
|
|
||||||
|
|
||||||
class PorterDuffImageComposite:
|
class PorterDuffImageComposite(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="PorterDuffImageComposite",
|
||||||
"source": ("IMAGE",),
|
display_name="Porter-Duff Image Composite",
|
||||||
"source_alpha": ("MASK",),
|
category="mask/compositing",
|
||||||
"destination": ("IMAGE",),
|
inputs=[
|
||||||
"destination_alpha": ("MASK",),
|
io.Image.Input("source"),
|
||||||
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
|
io.Mask.Input("source_alpha"),
|
||||||
},
|
io.Image.Input("destination"),
|
||||||
}
|
io.Mask.Input("destination_alpha"),
|
||||||
|
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
io.Mask.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK")
|
@classmethod
|
||||||
FUNCTION = "composite"
|
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
|
||||||
CATEGORY = "mask/compositing"
|
|
||||||
|
|
||||||
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
|
|
||||||
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
||||||
out_images = []
|
out_images = []
|
||||||
out_alphas = []
|
out_alphas = []
|
||||||
@ -150,45 +157,48 @@ class PorterDuffImageComposite:
|
|||||||
out_images.append(out_image)
|
out_images.append(out_image)
|
||||||
out_alphas.append(out_alpha.squeeze(2))
|
out_alphas.append(out_alpha.squeeze(2))
|
||||||
|
|
||||||
result = (torch.stack(out_images), torch.stack(out_alphas))
|
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class SplitImageWithAlpha:
|
class SplitImageWithAlpha(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="SplitImageWithAlpha",
|
||||||
"image": ("IMAGE",),
|
display_name="Split Image with Alpha",
|
||||||
}
|
category="mask/compositing",
|
||||||
}
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
io.Mask.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask/compositing"
|
@classmethod
|
||||||
RETURN_TYPES = ("IMAGE", "MASK")
|
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
|
||||||
FUNCTION = "split_image_with_alpha"
|
|
||||||
|
|
||||||
def split_image_with_alpha(self, image: torch.Tensor):
|
|
||||||
out_images = [i[:,:,:3] for i in image]
|
out_images = [i[:,:,:3] for i in image]
|
||||||
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
||||||
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class JoinImageWithAlpha:
|
class JoinImageWithAlpha(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="JoinImageWithAlpha",
|
||||||
"image": ("IMAGE",),
|
display_name="Join Image with Alpha",
|
||||||
"alpha": ("MASK",),
|
category="mask/compositing",
|
||||||
}
|
inputs=[
|
||||||
}
|
io.Image.Input("image"),
|
||||||
|
io.Mask.Input("alpha"),
|
||||||
|
],
|
||||||
|
outputs=[io.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask/compositing"
|
@classmethod
|
||||||
RETURN_TYPES = ("IMAGE",)
|
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||||
FUNCTION = "join_image_with_alpha"
|
|
||||||
|
|
||||||
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
|
|
||||||
batch_size = min(len(image), len(alpha))
|
batch_size = min(len(image), len(alpha))
|
||||||
out_images = []
|
out_images = []
|
||||||
|
|
||||||
@ -196,19 +206,18 @@ class JoinImageWithAlpha:
|
|||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||||
|
|
||||||
result = (torch.stack(out_images),)
|
return io.NodeOutput(torch.stack(out_images))
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class CompositingExtension(ComfyExtension):
|
||||||
"PorterDuffImageComposite": PorterDuffImageComposite,
|
@override
|
||||||
"SplitImageWithAlpha": SplitImageWithAlpha,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"JoinImageWithAlpha": JoinImageWithAlpha,
|
return [
|
||||||
}
|
PorterDuffImageComposite,
|
||||||
|
SplitImageWithAlpha,
|
||||||
|
JoinImageWithAlpha,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
async def comfy_entrypoint() -> CompositingExtension:
|
||||||
"PorterDuffImageComposite": "Porter-Duff Image Composite",
|
return CompositingExtension()
|
||||||
"SplitImageWithAlpha": "Split Image with Alpha",
|
|
||||||
"JoinImageWithAlpha": "Join Image with Alpha",
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,60 +1,80 @@
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class CLIPTextEncodeFlux:
|
|
||||||
|
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="CLIPTextEncodeFlux",
|
||||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="advanced/conditioning/flux",
|
||||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
|
||||||
def encode(self, clip, clip_l, t5xxl, guidance):
|
|
||||||
tokens = clip.tokenize(clip_l)
|
tokens = clip.tokenize(clip_l)
|
||||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||||
|
|
||||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
|
||||||
|
|
||||||
class FluxGuidance:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class FluxGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxGuidance",
|
||||||
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
|
category="advanced/conditioning/flux",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning, guidance) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
|
|
||||||
def append(self, conditioning, guidance):
|
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class FluxDisableGuidance:
|
class FluxDisableGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxDisableGuidance",
|
||||||
}}
|
category="advanced/conditioning/flux",
|
||||||
|
description="This node completely disables the guidance embed on Flux and Flux like models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
|
|
||||||
|
|
||||||
def append(self, conditioning):
|
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||||
@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FluxKontextImageScale:
|
class FluxKontextImageScale(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"image": ("IMAGE", ),
|
return io.Schema(
|
||||||
},
|
node_id="FluxKontextImageScale",
|
||||||
}
|
category="advanced/conditioning/flux",
|
||||||
|
description="This node resizes the image to one that is more optimal for flux kontext.",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
@classmethod
|
||||||
FUNCTION = "scale"
|
def execute(cls, image) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
|
|
||||||
|
|
||||||
def scale(self, image):
|
|
||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
height = image.shape[1]
|
height = image.shape[1]
|
||||||
aspect_ratio = width / height
|
aspect_ratio = width / height
|
||||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||||
return (image, )
|
return io.NodeOutput(image)
|
||||||
|
|
||||||
|
scale = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class FluxKontextMultiReferenceLatentMethod:
|
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxKontextMultiReferenceLatentMethod",
|
||||||
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
|
category="advanced/conditioning/flux",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.Combo.Input(
|
||||||
|
"reference_latents_method",
|
||||||
|
options=["offset", "index", "uxo/uno"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
|
|
||||||
def append(self, conditioning, reference_latents_method):
|
|
||||||
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
||||||
reference_latents_method = "uxo"
|
reference_latents_method = "uxo"
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
append = execute # TODO: remove
|
||||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
|
||||||
"FluxGuidance": FluxGuidance,
|
|
||||||
"FluxDisableGuidance": FluxDisableGuidance,
|
class FluxExtension(ComfyExtension):
|
||||||
"FluxKontextImageScale": FluxKontextImageScale,
|
@override
|
||||||
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
}
|
return [
|
||||||
|
CLIPTextEncodeFlux,
|
||||||
|
FluxGuidance,
|
||||||
|
FluxDisableGuidance,
|
||||||
|
FluxKontextImageScale,
|
||||||
|
FluxKontextMultiReferenceLatentMethod,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> FluxExtension:
|
||||||
|
return FluxExtension()
|
||||||
|
|||||||
@ -2,6 +2,8 @@ import comfy.utils
|
|||||||
import comfy_extras.nodes_post_processing
|
import comfy_extras.nodes_post_processing
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
import nodes
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||||
@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
|||||||
return latent
|
return latent
|
||||||
|
|
||||||
|
|
||||||
class LatentAdd:
|
class LatentAdd(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
return io.Schema(
|
||||||
|
node_id="LatentAdd",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -31,19 +39,25 @@ class LatentAdd:
|
|||||||
|
|
||||||
s2 = reshape_latent_to(s1.shape, s2)
|
s2 = reshape_latent_to(s1.shape, s2)
|
||||||
samples_out["samples"] = s1 + s2
|
samples_out["samples"] = s1 + s2
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentSubtract:
|
class LatentSubtract(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
return io.Schema(
|
||||||
|
node_id="LatentSubtract",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -51,41 +65,49 @@ class LatentSubtract:
|
|||||||
|
|
||||||
s2 = reshape_latent_to(s1.shape, s2)
|
s2 = reshape_latent_to(s1.shape, s2)
|
||||||
samples_out["samples"] = s1 - s2
|
samples_out["samples"] = s1 - s2
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentMultiply:
|
class LatentMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return io.Schema(
|
||||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
node_id="LatentMultiply",
|
||||||
}}
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, multiplier) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples, multiplier):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
|
|
||||||
s1 = samples["samples"]
|
s1 = samples["samples"]
|
||||||
samples_out["samples"] = s1 * multiplier
|
samples_out["samples"] = s1 * multiplier
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentInterpolate:
|
class LatentInterpolate(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",),
|
return io.Schema(
|
||||||
"samples2": ("LATENT",),
|
node_id="LatentInterpolate",
|
||||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="latent/advanced",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2, ratio):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -104,19 +126,26 @@ class LatentInterpolate:
|
|||||||
st = torch.nan_to_num(t / mt)
|
st = torch.nan_to_num(t / mt)
|
||||||
|
|
||||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentConcat:
|
class LatentConcat(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
|
return io.Schema(
|
||||||
|
node_id="LatentConcat",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2, dim):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -136,22 +165,27 @@ class LatentConcat:
|
|||||||
dim = -3
|
dim = -3
|
||||||
|
|
||||||
samples_out["samples"] = torch.cat(c, dim=dim)
|
samples_out["samples"] = torch.cat(c, dim=dim)
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentCut:
|
class LatentCut(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"samples": ("LATENT",),
|
return io.Schema(
|
||||||
"dim": (["x", "y", "t"], ),
|
node_id="LatentCut",
|
||||||
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
|
category="latent/advanced",
|
||||||
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.Combo.Input("dim", options=["x", "y", "t"]),
|
||||||
|
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples, dim, index, amount):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
|
|
||||||
s1 = samples["samples"]
|
s1 = samples["samples"]
|
||||||
@ -171,19 +205,25 @@ class LatentCut:
|
|||||||
amount = min(-index, amount)
|
amount = min(-index, amount)
|
||||||
|
|
||||||
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentBatch:
|
class LatentBatch(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
return io.Schema(
|
||||||
|
node_id="LatentBatch",
|
||||||
|
category="latent/batch",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "batch"
|
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/batch"
|
|
||||||
|
|
||||||
def batch(self, samples1, samples2):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
s2 = samples2["samples"]
|
s2 = samples2["samples"]
|
||||||
@ -192,20 +232,25 @@ class LatentBatch:
|
|||||||
s = torch.cat((s1, s2), dim=0)
|
s = torch.cat((s1, s2), dim=0)
|
||||||
samples_out["samples"] = s
|
samples_out["samples"] = s
|
||||||
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentBatchSeedBehavior:
|
class LatentBatchSeedBehavior(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return io.Schema(
|
||||||
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
|
node_id="LatentBatchSeedBehavior",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, seed_behavior) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples, seed_behavior):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
latent = samples["samples"]
|
latent = samples["samples"]
|
||||||
if seed_behavior == "random":
|
if seed_behavior == "random":
|
||||||
@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
|
|||||||
batch_number = samples_out.get("batch_index", [0])[0]
|
batch_number = samples_out.get("batch_index", [0])[0]
|
||||||
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
||||||
|
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentApplyOperation:
|
class LatentApplyOperation(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return io.Schema(
|
||||||
"operation": ("LATENT_OPERATION",),
|
node_id="LatentApplyOperation",
|
||||||
}}
|
category="latent/advanced/operations",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.LatentOperation.Input("operation"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, operation) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def op(self, samples, operation):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
|
|
||||||
s1 = samples["samples"]
|
s1 = samples["samples"]
|
||||||
samples_out["samples"] = operation(latent=s1)
|
samples_out["samples"] = operation(latent=s1)
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentApplyOperationCFG:
|
class LatentApplyOperationCFG(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"operation": ("LATENT_OPERATION",),
|
node_id="LatentApplyOperationCFG",
|
||||||
}}
|
category="latent/advanced/operations",
|
||||||
RETURN_TYPES = ("MODEL",)
|
is_experimental=True,
|
||||||
FUNCTION = "patch"
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.LatentOperation.Input("operation"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
@classmethod
|
||||||
EXPERIMENTAL = True
|
def execute(cls, model, operation) -> io.NodeOutput:
|
||||||
|
|
||||||
def patch(self, model, operation):
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
def pre_cfg_function(args):
|
def pre_cfg_function(args):
|
||||||
@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
|
|||||||
return conds_out
|
return conds_out
|
||||||
|
|
||||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class LatentOperationTonemapReinhard:
|
class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
return io.Schema(
|
||||||
}}
|
node_id="LatentOperationTonemapReinhard",
|
||||||
|
category="latent/advanced/operations",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.LatentOperation.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, multiplier) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def op(self, multiplier):
|
|
||||||
def tonemap_reinhard(latent, **kwargs):
|
def tonemap_reinhard(latent, **kwargs):
|
||||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||||
normalized_latent = latent / latent_vector_magnitude
|
normalized_latent = latent / latent_vector_magnitude
|
||||||
@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
|
|||||||
new_magnitude *= top
|
new_magnitude *= top
|
||||||
|
|
||||||
return normalized_latent * new_magnitude
|
return normalized_latent * new_magnitude
|
||||||
return (tonemap_reinhard,)
|
return io.NodeOutput(tonemap_reinhard)
|
||||||
|
|
||||||
class LatentOperationSharpen:
|
class LatentOperationSharpen(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"sharpen_radius": ("INT", {
|
node_id="LatentOperationSharpen",
|
||||||
"default": 9,
|
category="latent/advanced/operations",
|
||||||
"min": 1,
|
is_experimental=True,
|
||||||
"max": 31,
|
inputs=[
|
||||||
"step": 1
|
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
|
||||||
}),
|
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
|
||||||
"sigma": ("FLOAT", {
|
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
|
||||||
"default": 1.0,
|
],
|
||||||
"min": 0.1,
|
outputs=[
|
||||||
"max": 10.0,
|
io.LatentOperation.Output(),
|
||||||
"step": 0.1
|
],
|
||||||
}),
|
)
|
||||||
"alpha": ("FLOAT", {
|
|
||||||
"default": 0.1,
|
|
||||||
"min": 0.0,
|
|
||||||
"max": 5.0,
|
|
||||||
"step": 0.01
|
|
||||||
}),
|
|
||||||
}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def op(self, sharpen_radius, sigma, alpha):
|
|
||||||
def sharpen(latent, **kwargs):
|
def sharpen(latent, **kwargs):
|
||||||
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
||||||
normalized_latent = latent / luminance
|
normalized_latent = latent / luminance
|
||||||
@ -340,19 +386,27 @@ class LatentOperationSharpen:
|
|||||||
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||||
|
|
||||||
return luminance * sharpened
|
return luminance * sharpened
|
||||||
return (sharpen,)
|
return io.NodeOutput(sharpen)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"LatentAdd": LatentAdd,
|
class LatentExtension(ComfyExtension):
|
||||||
"LatentSubtract": LatentSubtract,
|
@override
|
||||||
"LatentMultiply": LatentMultiply,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"LatentInterpolate": LatentInterpolate,
|
return [
|
||||||
"LatentConcat": LatentConcat,
|
LatentAdd,
|
||||||
"LatentCut": LatentCut,
|
LatentSubtract,
|
||||||
"LatentBatch": LatentBatch,
|
LatentMultiply,
|
||||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
LatentInterpolate,
|
||||||
"LatentApplyOperation": LatentApplyOperation,
|
LatentConcat,
|
||||||
"LatentApplyOperationCFG": LatentApplyOperationCFG,
|
LatentCut,
|
||||||
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
|
LatentBatch,
|
||||||
"LatentOperationSharpen": LatentOperationSharpen,
|
LatentBatchSeedBehavior,
|
||||||
}
|
LatentApplyOperation,
|
||||||
|
LatentApplyOperationCFG,
|
||||||
|
LatentOperationTonemapReinhard,
|
||||||
|
LatentOperationSharpen,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LatentExtension:
|
||||||
|
return LatentExtension()
|
||||||
|
|||||||
@ -5,6 +5,8 @@ import folder_paths
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
|
|||||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||||
return output_sd
|
return output_sd
|
||||||
|
|
||||||
class LoraSave:
|
class LoraSave(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LoraSave",
|
||||||
|
display_name="Extract and Save Lora",
|
||||||
|
category="_for_testing",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||||
|
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
|
||||||
|
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
|
||||||
|
io.Boolean.Input("bias_diff", default=True),
|
||||||
|
io.Model.Input(
|
||||||
|
"model_diff",
|
||||||
|
tooltip="The ModelSubtract output to be converted to a lora.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
io.Clip.Input(
|
||||||
|
"text_encoder_diff",
|
||||||
|
tooltip="The CLIPSubtract output to be converted to a lora.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
|
||||||
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
|
||||||
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
|
||||||
"lora_type": (tuple(LORA_TYPES.keys()),),
|
|
||||||
"bias_diff": ("BOOLEAN", {"default": True}),
|
|
||||||
},
|
|
||||||
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
|
|
||||||
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
|
|
||||||
}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
|
||||||
|
|
||||||
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
|
||||||
if model_diff is None and text_encoder_diff is None:
|
if model_diff is None and text_encoder_diff is None:
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
lora_type = LORA_TYPES.get(lora_type)
|
lora_type = LORA_TYPES.get(lora_type)
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||||
|
|
||||||
output_sd = {}
|
output_sd = {}
|
||||||
if model_diff is not None:
|
if model_diff is not None:
|
||||||
@ -108,12 +118,16 @@ class LoraSave:
|
|||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"LoraSave": LoraSave
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
class LoraSaveExtension(ComfyExtension):
|
||||||
"LoraSave": "Extract and Save Lora"
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
LoraSave,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LoraSaveExtension:
|
||||||
|
return LoraSaveExtension()
|
||||||
|
|||||||
@ -1,24 +1,33 @@
|
|||||||
|
from typing_extensions import override
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class PatchModelAddDownscale:
|
|
||||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
class PatchModelAddDownscale(io.ComfyNode):
|
||||||
|
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
|
node_id="PatchModelAddDownscale",
|
||||||
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
|
display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
category="model_patches/unet",
|
||||||
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
inputs=[
|
||||||
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
io.Model.Input("model"),
|
||||||
"downscale_method": (s.upscale_methods,),
|
io.Int.Input("block_number", default=3, min=1, max=32, step=1),
|
||||||
"upscale_method": (s.upscale_methods,),
|
io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
|
||||||
}}
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
|
||||||
FUNCTION = "patch"
|
io.Boolean.Input("downscale_after_skip", default=True),
|
||||||
|
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
|
||||||
|
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "model_patches/unet"
|
@classmethod
|
||||||
|
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
|
||||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
@ -41,13 +50,21 @@ class PatchModelAddDownscale:
|
|||||||
else:
|
else:
|
||||||
m.set_model_input_block_patch(input_block_patch)
|
m.set_model_input_block_patch(input_block_patch)
|
||||||
m.set_model_output_block_patch(output_block_patch)
|
m.set_model_output_block_patch(output_block_patch)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"PatchModelAddDownscale": PatchModelAddDownscale,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
|
"PatchModelAddDownscale": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ModelDownscaleExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
PatchModelAddDownscale,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ModelDownscaleExtension:
|
||||||
|
return ModelDownscaleExtension()
|
||||||
|
|||||||
@ -3,64 +3,83 @@ import comfy.sd
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import nodes
|
import nodes
|
||||||
import torch
|
import torch
|
||||||
import comfy_extras.nodes_slg
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from comfy_extras.nodes_slg import SkipLayerGuidanceDiT
|
||||||
|
|
||||||
|
|
||||||
class TripleCLIPLoader:
|
class TripleCLIPLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), )
|
return io.Schema(
|
||||||
}}
|
node_id="TripleCLIPLoader",
|
||||||
RETURN_TYPES = ("CLIP",)
|
category="advanced/loaders",
|
||||||
FUNCTION = "load_clip"
|
description="[Recipes]\n\nsd3: clip-l, clip-g, t5",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Clip.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/loaders"
|
@classmethod
|
||||||
|
def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput:
|
||||||
DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5"
|
|
||||||
|
|
||||||
def load_clip(self, clip_name1, clip_name2, clip_name3):
|
|
||||||
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
|
||||||
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
|
||||||
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
|
||||||
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (clip,)
|
return io.NodeOutput(clip)
|
||||||
|
|
||||||
|
load_clip = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class EmptySD3LatentImage:
|
class EmptySD3LatentImage(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.device = comfy.model_management.intermediate_device()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptySD3LatentImage",
|
||||||
|
category="latent/sd3",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
||||||
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
return io.NodeOutput({"samples":latent})
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/sd3"
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
|
||||||
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
|
|
||||||
return ({"samples":latent}, )
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeSD3:
|
class CLIPTextEncodeSD3(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="CLIPTextEncodeSD3",
|
||||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="advanced/conditioning",
|
||||||
"clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
io.Clip.Input("clip"),
|
||||||
"empty_padding": (["none", "empty_prompt"], )
|
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||||
}}
|
io.String.Input("clip_g", multiline=True, dynamic_prompts=True),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Combo.Input("empty_padding", options=["none", "empty_prompt"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning"
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput:
|
||||||
def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding):
|
|
||||||
no_padding = empty_padding == "none"
|
no_padding = empty_padding == "none"
|
||||||
|
|
||||||
tokens = clip.tokenize(clip_g)
|
tokens = clip.tokenize(clip_g)
|
||||||
@ -82,57 +101,112 @@ class CLIPTextEncodeSD3:
|
|||||||
tokens["l"] += empty["l"]
|
tokens["l"] += empty["l"]
|
||||||
while len(tokens["l"]) > len(tokens["g"]):
|
while len(tokens["l"]) > len(tokens["g"]):
|
||||||
tokens["g"] += empty["g"]
|
tokens["g"] += empty["g"]
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||||
|
|
||||||
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
|
class ControlNetApplySD3(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING", ),
|
node_id="ControlNetApplySD3",
|
||||||
"control_net": ("CONTROL_NET", ),
|
display_name="Apply Controlnet with VAE",
|
||||||
"vae": ("VAE", ),
|
category="conditioning/controlnet",
|
||||||
"image": ("IMAGE", ),
|
inputs=[
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Conditioning.Input("positive"),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
io.Conditioning.Input("negative"),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
io.ControlNet.Input("control_net"),
|
||||||
}}
|
io.Vae.Input("vae"),
|
||||||
CATEGORY = "conditioning/controlnet"
|
io.Image.Input("image"),
|
||||||
DEPRECATED = True
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
is_deprecated=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput:
|
||||||
|
if strength == 0:
|
||||||
|
return io.NodeOutput(positive, negative)
|
||||||
|
|
||||||
|
control_hint = image.movedim(-1, 1)
|
||||||
|
cnets = {}
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for conditioning in [positive, negative]:
|
||||||
|
c = []
|
||||||
|
for t in conditioning:
|
||||||
|
d = t[1].copy()
|
||||||
|
|
||||||
|
prev_cnet = d.get('control', None)
|
||||||
|
if prev_cnet in cnets:
|
||||||
|
c_net = cnets[prev_cnet]
|
||||||
|
else:
|
||||||
|
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent),
|
||||||
|
vae=vae, extra_concat=[])
|
||||||
|
c_net.set_previous_controlnet(prev_cnet)
|
||||||
|
cnets[prev_cnet] = c_net
|
||||||
|
|
||||||
|
d['control'] = c_net
|
||||||
|
d['control_apply_to_uncond'] = False
|
||||||
|
n = [t[0], d]
|
||||||
|
c.append(n)
|
||||||
|
out.append(c)
|
||||||
|
return io.NodeOutput(out[0], out[1])
|
||||||
|
|
||||||
|
apply_controlnet = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
|
class SkipLayerGuidanceSD3(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
Experimental implementation by Dango233@StabilityAI.
|
Experimental implementation by Dango233@StabilityAI.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceSD3",
|
||||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
category="advanced/guidance",
|
||||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
|
inputs=[
|
||||||
}}
|
io.Model.Input("model"),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.String.Input("layers", default="7, 8, 9", multiline=False),
|
||||||
FUNCTION = "skip_guidance_sd3"
|
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||||
|
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/guidance"
|
@classmethod
|
||||||
|
def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput:
|
||||||
|
return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
||||||
|
|
||||||
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
|
skip_guidance_sd3 = execute # TODO: remove
|
||||||
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class SD3Extension(ComfyExtension):
|
||||||
"TripleCLIPLoader": TripleCLIPLoader,
|
@override
|
||||||
"EmptySD3LatentImage": EmptySD3LatentImage,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"CLIPTextEncodeSD3": CLIPTextEncodeSD3,
|
return [
|
||||||
"ControlNetApplySD3": ControlNetApplySD3,
|
TripleCLIPLoader,
|
||||||
"SkipLayerGuidanceSD3": SkipLayerGuidanceSD3,
|
EmptySD3LatentImage,
|
||||||
}
|
CLIPTextEncodeSD3,
|
||||||
|
ControlNetApplySD3,
|
||||||
|
SkipLayerGuidanceSD3,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
# Sampling
|
async def comfy_entrypoint() -> SD3Extension:
|
||||||
"ControlNetApplySD3": "Apply Controlnet with VAE",
|
return SD3Extension()
|
||||||
}
|
|
||||||
|
|||||||
@ -1,33 +1,40 @@
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import re
|
import re
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class SkipLayerGuidanceDiT:
|
class SkipLayerGuidanceDiT(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
|
||||||
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
|
||||||
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
Original experimental implementation for SD3 by Dango233@StabilityAI.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceDiT",
|
||||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
category="advanced/guidance",
|
||||||
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
|
description="Generic version of SkipLayerGuidance node that can be used on every DiT model.",
|
||||||
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
|
is_experimental=True,
|
||||||
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}),
|
inputs=[
|
||||||
"rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.String.Input("double_layers", default="7, 8, 9"),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.String.Input("single_layers", default="7, 8, 9"),
|
||||||
FUNCTION = "skip_guidance"
|
io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1),
|
||||||
EXPERIMENTAL = True
|
io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."
|
@classmethod
|
||||||
|
def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput:
|
||||||
CATEGORY = "advanced/guidance"
|
|
||||||
|
|
||||||
def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0):
|
|
||||||
# check if layer is comma separated integers
|
# check if layer is comma separated integers
|
||||||
def skip(args, extra_args):
|
def skip(args, extra_args):
|
||||||
return args
|
return args
|
||||||
@ -43,7 +50,7 @@ class SkipLayerGuidanceDiT:
|
|||||||
single_layers = [int(i) for i in single_layers]
|
single_layers = [int(i) for i in single_layers]
|
||||||
|
|
||||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||||
return (model, )
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
def post_cfg_function(args):
|
def post_cfg_function(args):
|
||||||
model = args["model"]
|
model = args["model"]
|
||||||
@ -76,29 +83,36 @@ class SkipLayerGuidanceDiT:
|
|||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
m.set_model_sampler_post_cfg_function(post_cfg_function)
|
||||||
|
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class SkipLayerGuidanceDiTSimple:
|
skip_guidance = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class SkipLayerGuidanceDiTSimple(io.ComfyNode):
|
||||||
'''
|
'''
|
||||||
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.
|
||||||
'''
|
'''
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"model": ("MODEL", ),
|
return io.Schema(
|
||||||
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
node_id="SkipLayerGuidanceDiTSimple",
|
||||||
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
|
category="advanced/guidance",
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.",
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
is_experimental=True,
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Model.Input("model"),
|
||||||
FUNCTION = "skip_guidance"
|
io.String.Input("double_layers", default="7, 8, 9"),
|
||||||
EXPERIMENTAL = True
|
io.String.Input("single_layers", default="7, 8, 9"),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass."
|
@classmethod
|
||||||
|
def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput:
|
||||||
CATEGORY = "advanced/guidance"
|
|
||||||
|
|
||||||
def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""):
|
|
||||||
def skip(args, extra_args):
|
def skip(args, extra_args):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -113,7 +127,7 @@ class SkipLayerGuidanceDiTSimple:
|
|||||||
single_layers = [int(i) for i in single_layers]
|
single_layers = [int(i) for i in single_layers]
|
||||||
|
|
||||||
if len(double_layers) == 0 and len(single_layers) == 0:
|
if len(double_layers) == 0 and len(single_layers) == 0:
|
||||||
return (model, )
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
def calc_cond_batch_function(args):
|
def calc_cond_batch_function(args):
|
||||||
x = args["input"]
|
x = args["input"]
|
||||||
@ -144,9 +158,19 @@ class SkipLayerGuidanceDiTSimple:
|
|||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function)
|
||||||
|
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
skip_guidance = execute # TODO: remove
|
||||||
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
|
|
||||||
"SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple,
|
|
||||||
}
|
class SkipLayerGuidanceExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
SkipLayerGuidanceDiT,
|
||||||
|
SkipLayerGuidanceDiTSimple,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> SkipLayerGuidanceExtension:
|
||||||
|
return SkipLayerGuidanceExtension()
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from comfy import model_management
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from spandrel_extra_arches import EXTRA_REGISTRY
|
from spandrel_extra_arches import EXTRA_REGISTRY
|
||||||
@ -13,17 +15,23 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class UpscaleModelLoader:
|
class UpscaleModelLoader(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
|
return io.Schema(
|
||||||
}}
|
node_id="UpscaleModelLoader",
|
||||||
RETURN_TYPES = ("UPSCALE_MODEL",)
|
display_name="Load Upscale Model",
|
||||||
FUNCTION = "load_model"
|
category="loaders",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.UpscaleModel.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
@classmethod
|
||||||
|
def execute(cls, model_name) -> io.NodeOutput:
|
||||||
def load_model(self, model_name):
|
|
||||||
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
model_path = folder_paths.get_full_path_or_raise("upscale_models", model_name)
|
||||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||||
@ -33,21 +41,29 @@ class UpscaleModelLoader:
|
|||||||
if not isinstance(out, ImageModelDescriptor):
|
if not isinstance(out, ImageModelDescriptor):
|
||||||
raise Exception("Upscale model must be a single-image model.")
|
raise Exception("Upscale model must be a single-image model.")
|
||||||
|
|
||||||
return (out, )
|
return io.NodeOutput(out)
|
||||||
|
|
||||||
|
load_model = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ImageUpscaleWithModel:
|
class ImageUpscaleWithModel(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "upscale_model": ("UPSCALE_MODEL",),
|
return io.Schema(
|
||||||
"image": ("IMAGE",),
|
node_id="ImageUpscaleWithModel",
|
||||||
}}
|
display_name="Upscale Image (using Model)",
|
||||||
RETURN_TYPES = ("IMAGE",)
|
category="image/upscaling",
|
||||||
FUNCTION = "upscale"
|
inputs=[
|
||||||
|
io.UpscaleModel.Input("upscale_model"),
|
||||||
|
io.Image.Input("image"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "image/upscaling"
|
@classmethod
|
||||||
|
def execute(cls, upscale_model, image) -> io.NodeOutput:
|
||||||
def upscale(self, upscale_model, image):
|
|
||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
|
|
||||||
memory_required = model_management.module_size(upscale_model.model)
|
memory_required = model_management.module_size(upscale_model.model)
|
||||||
@ -75,9 +91,19 @@ class ImageUpscaleWithModel:
|
|||||||
|
|
||||||
upscale_model.to("cpu")
|
upscale_model.to("cpu")
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||||
return (s,)
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
upscale = execute # TODO: remove
|
||||||
"UpscaleModelLoader": UpscaleModelLoader,
|
|
||||||
"ImageUpscaleWithModel": ImageUpscaleWithModel
|
|
||||||
}
|
class UpscaleModelExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
UpscaleModelLoader,
|
||||||
|
ImageUpscaleWithModel,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> UpscaleModelExtension:
|
||||||
|
return UpscaleModelExtension()
|
||||||
|
|||||||
2
nodes.py
2
nodes.py
@ -2027,7 +2027,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"DiffControlNetLoader": "Load ControlNet Model (diff)",
|
"DiffControlNetLoader": "Load ControlNet Model (diff)",
|
||||||
"StyleModelLoader": "Load Style Model",
|
"StyleModelLoader": "Load Style Model",
|
||||||
"CLIPVisionLoader": "Load CLIP Vision",
|
"CLIPVisionLoader": "Load CLIP Vision",
|
||||||
"UpscaleModelLoader": "Load Upscale Model",
|
|
||||||
"UNETLoader": "Load Diffusion Model",
|
"UNETLoader": "Load Diffusion Model",
|
||||||
# Conditioning
|
# Conditioning
|
||||||
"CLIPVisionEncode": "CLIP Vision Encode",
|
"CLIPVisionEncode": "CLIP Vision Encode",
|
||||||
@ -2065,7 +2064,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LoadImageOutput": "Load Image (from Outputs)",
|
"LoadImageOutput": "Load Image (from Outputs)",
|
||||||
"ImageScale": "Upscale Image",
|
"ImageScale": "Upscale Image",
|
||||||
"ImageScaleBy": "Upscale Image By",
|
"ImageScaleBy": "Upscale Image By",
|
||||||
"ImageUpscaleWithModel": "Upscale Image (using Model)",
|
|
||||||
"ImageInvert": "Invert Image",
|
"ImageInvert": "Invert Image",
|
||||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||||
"ImageBatch": "Batch Images",
|
"ImageBatch": "Batch Images",
|
||||||
|
|||||||
@ -61,7 +61,6 @@ messages_control.disable = [
|
|||||||
# next warnings should be fixed in future
|
# next warnings should be fixed in future
|
||||||
"bad-classmethod-argument", # Class method should have 'cls' as first argument
|
"bad-classmethod-argument", # Class method should have 'cls' as first argument
|
||||||
"wrong-import-order", # Standard imports should be placed before third party imports
|
"wrong-import-order", # Standard imports should be placed before third party imports
|
||||||
"logging-fstring-interpolation", # Use lazy % formatting in logging functions
|
|
||||||
"ungrouped-imports",
|
"ungrouped-imports",
|
||||||
"unnecessary-pass",
|
"unnecessary-pass",
|
||||||
"unnecessary-lambda-assignment",
|
"unnecessary-lambda-assignment",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user