mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 01:37:45 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
dfc47e0611
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
2
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -22,7 +22,7 @@ body:
|
|||||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||||
options:
|
options:
|
||||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||||
required: true
|
required: false
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
label: Expected Behavior
|
label: Expected Behavior
|
||||||
|
|||||||
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
2
.github/ISSUE_TEMPLATE/user-support.yml
vendored
@ -18,7 +18,7 @@ body:
|
|||||||
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
|
||||||
options:
|
options:
|
||||||
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
|
||||||
required: true
|
required: false
|
||||||
- type: textarea
|
- type: textarea
|
||||||
attributes:
|
attributes:
|
||||||
label: Your question
|
label: Your question
|
||||||
|
|||||||
27
CODEOWNERS
27
CODEOWNERS
@ -5,20 +5,21 @@
|
|||||||
# Inlined the team members for now.
|
# Inlined the team members for now.
|
||||||
|
|
||||||
# Maintainers
|
# Maintainers
|
||||||
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
|
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne @guill
|
||||||
|
|
||||||
# Python web server
|
# Python web server
|
||||||
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||||
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||||
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
|
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne @guill
|
||||||
|
|
||||||
# Node developers
|
# Node developers
|
||||||
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||||
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
|
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||||
|
/comfy_api_nodes/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne @guill
|
||||||
|
|||||||
@ -44,6 +44,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
|||||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||||
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
|
||||||
|
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
|
||||||
- Video Models
|
- Video Models
|
||||||
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
|
||||||
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.49"
|
__version__ = "0.3.51"
|
||||||
|
|||||||
@ -363,10 +363,17 @@ class UserManager():
|
|||||||
if not overwrite and os.path.exists(path):
|
if not overwrite and os.path.exists(path):
|
||||||
return web.Response(status=409, text="File already exists")
|
return web.Response(status=409, text="File already exists")
|
||||||
|
|
||||||
body = await request.read()
|
try:
|
||||||
|
body = await request.read()
|
||||||
|
|
||||||
with open(path, "wb") as f:
|
with open(path, "wb") as f:
|
||||||
f.write(body)
|
f.write(body)
|
||||||
|
except OSError as e:
|
||||||
|
logging.warning(f"Error saving file '{path}': {e}")
|
||||||
|
return web.Response(
|
||||||
|
status=400,
|
||||||
|
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
|
||||||
|
)
|
||||||
|
|
||||||
user_path = self.get_request_user_filepath(request, None)
|
user_path = self.get_request_user_filepath(request, None)
|
||||||
if full_info:
|
if full_info:
|
||||||
|
|||||||
@ -127,6 +127,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
|||||||
|
|
||||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||||
|
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||||
parser.add_argument("--disable-smart-memory", action="store_true",
|
parser.add_argument("--disable-smart-memory", action="store_true",
|
||||||
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||||
|
|||||||
@ -155,6 +155,7 @@ class Configuration(dict):
|
|||||||
cache_classic (bool): WARNING: Unused. Use the old style (aggressive) caching.
|
cache_classic (bool): WARNING: Unused. Use the old style (aggressive) caching.
|
||||||
cache_none (bool): Reduced RAM/VRAM usage at the expense of executing every node for each run.
|
cache_none (bool): Reduced RAM/VRAM usage at the expense of executing every node for each run.
|
||||||
async_offload (bool): Use async weight offloading.
|
async_offload (bool): Use async weight offloading.
|
||||||
|
force_non_blocking (bool): Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.
|
||||||
default_hashing_function (str): Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.
|
default_hashing_function (str): Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.
|
||||||
mmap_torch_files (bool): Use mmap when loading ckpt/pt files.
|
mmap_torch_files (bool): Use mmap when loading ckpt/pt files.
|
||||||
disable_mmap (bool): Don't use mmap when loading safetensors.
|
disable_mmap (bool): Don't use mmap when loading safetensors.
|
||||||
@ -274,6 +275,7 @@ class Configuration(dict):
|
|||||||
self.cache_classic: bool = False
|
self.cache_classic: bool = False
|
||||||
self.cache_none: bool = False
|
self.cache_none: bool = False
|
||||||
self.async_offload: bool = False
|
self.async_offload: bool = False
|
||||||
|
self.force_non_blocking: bool = False
|
||||||
self.default_hashing_function: str = 'sha256'
|
self.default_hashing_function: str = 'sha256'
|
||||||
self.mmap_torch_files: bool = False
|
self.mmap_torch_files: bool = False
|
||||||
self.disable_mmap: bool = False
|
self.disable_mmap: bool = False
|
||||||
|
|||||||
@ -98,7 +98,7 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||||
|
|
||||||
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
|
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
|
||||||
if embeds is not None:
|
if embeds is not None:
|
||||||
x = embeds + ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
x = embeds + ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -756,6 +756,13 @@ class PromptExecutor:
|
|||||||
if ex is not None and self.raise_exceptions:
|
if ex is not None and self.raise_exceptions:
|
||||||
raise ex
|
raise ex
|
||||||
|
|
||||||
|
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs=None):
    """Blocking entry point that drives ``execute_async`` to completion.

    ``extra_data`` and ``execute_outputs`` default to ``None`` so each call
    receives its own fresh dict / list instead of a shared mutable default.
    The coroutine is run with ``asyncio.run``; nothing is returned.
    """
    outputs_to_execute = [] if execute_outputs is None else execute_outputs
    extras = {} if extra_data is None else extra_data
    coroutine = self.execute_async(prompt, prompt_id, extras, outputs_to_execute)
    asyncio.run(coroutine)
|
||||||
|
|
||||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
# torchao and potentially other optimization approaches break when the models are created in inference mode
|
# torchao and potentially other optimization approaches break when the models are created in inference mode
|
||||||
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
|
# todo: this should really be backpropagated to code which creates ModelPatchers via lazy evaluation rather than globally checked here
|
||||||
@ -1109,7 +1116,7 @@ def full_type_name(klass):
|
|||||||
|
|
||||||
|
|
||||||
@tracer.start_as_current_span("Validate Prompt")
|
@tracer.start_as_current_span("Validate Prompt")
|
||||||
async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None]=None) -> ValidationTuple:
|
async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None] = None) -> ValidationTuple:
|
||||||
# todo: partial_execution_list=None, because nobody uses these features
|
# todo: partial_execution_list=None, because nobody uses these features
|
||||||
res = await _validate_prompt(prompt_id, prompt, partial_execution_list)
|
res = await _validate_prompt(prompt_id, prompt, partial_execution_list)
|
||||||
if not res.valid:
|
if not res.valid:
|
||||||
@ -1132,7 +1139,7 @@ async def validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typ
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None]=None) -> ValidationTuple:
|
async def _validate_prompt(prompt_id: typing.Any, prompt: typing.Mapping[str, typing.Any], partial_execution_list: typing.Union[list[str], None] = None) -> ValidationTuple:
|
||||||
outputs = set()
|
outputs = set()
|
||||||
for x in prompt:
|
for x in prompt:
|
||||||
if 'class_type' not in prompt[x]:
|
if 'class_type' not in prompt[x]:
|
||||||
|
|||||||
@ -109,6 +109,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
|
|||||||
ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)),
|
ModelPaths(["photomaker"], supported_extensions=set(supported_pt_extensions)),
|
||||||
ModelPaths(["classifiers"], supported_extensions=set()),
|
ModelPaths(["classifiers"], supported_extensions=set()),
|
||||||
ModelPaths(["huggingface"], supported_extensions=set()),
|
ModelPaths(["huggingface"], supported_extensions=set()),
|
||||||
|
ModelPaths(["model_patches"], supported_extensions=set(supported_pt_extensions)),
|
||||||
hf_cache_paths,
|
hf_cache_paths,
|
||||||
hf_xet,
|
hf_xet,
|
||||||
]
|
]
|
||||||
|
|||||||
540
comfy/context_windows.py
Normal file
540
comfy/context_windows.py
Normal file
@ -0,0 +1,540 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import collections
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import logging
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.patcher_extension
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_base import BaseModel
|
||||||
|
from comfy.model_patcher import ModelPatcher
|
||||||
|
from comfy.controlnet import ControlBase
|
||||||
|
|
||||||
|
|
||||||
|
class ContextWindowABC(ABC):
    """Abstract interface for a single context window over a larger tensor."""

    def __init__(self):
        pass

    @abstractmethod
    def get_tensor(self, full: torch.Tensor) -> torch.Tensor:
        """Return the part of ``full`` that applies to the current window."""
        raise NotImplementedError("Not implemented.")

    @abstractmethod
    def add_window(self, full: torch.Tensor, to_add: torch.Tensor) -> torch.Tensor:
        """Apply the window tensor ``to_add`` onto ``full``, in place.

        Implementations return a reference to the updated ``full`` tensor,
        not a copy.
        """
        raise NotImplementedError("Not implemented.")
|
||||||
|
|
||||||
|
class ContextHandlerABC(ABC):
    """Abstract interface for objects that split sampling into context windows."""

    def __init__(self):
        pass

    @abstractmethod
    def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
        """Decide whether context windows should be used for this input."""
        raise NotImplementedError("Not implemented.")

    @abstractmethod
    def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: ContextWindowABC, device=None) -> list:
        """Return a version of ``cond_in`` resized for ``window``."""
        raise NotImplementedError("Not implemented.")

    @abstractmethod
    def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
        """Run ``calc_cond_batch`` using this handler's context windows."""
        raise NotImplementedError("Not implemented.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class IndexListContextWindow(ContextWindowABC):
    """Context window defined by an explicit list of indices along one dim.

    Attributes:
        index_list: indices (along ``dim``) that belong to this window.
        context_length: number of indices in ``index_list``.
        dim: default dimension the indices refer to.
    """

    def __init__(self, index_list: list[int], dim: int=0):
        self.index_list = index_list
        self.context_length = len(index_list)
        self.dim = dim

    def _build_index(self, dim: int) -> tuple:
        """Build the indexing tuple that selects ``index_list`` along ``dim``."""
        # Must be a tuple: indexing a tensor with a non-tuple sequence that
        # contains slices is deprecated in PyTorch.
        return (slice(None),) * dim + (self.index_list,)

    def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
        """Return the entries of ``full`` selected by this window.

        Args:
            full: tensor to take the window from.
            device: device the selected tensor is moved to via ``.to(device)``.
            dim: dimension to index along; defaults to ``self.dim``.

        Returns:
            The indexed tensor. When ``dim`` is 0 and ``full`` has size 1 along
            it, ``full`` itself is returned unchanged (and is not moved to
            ``device``).
        """
        if dim is None:
            dim = self.dim
        if dim == 0 and full.shape[dim] == 1:
            return full
        return full[self._build_index(dim)].to(device)

    def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
        """Add ``to_add`` onto the entries of ``full`` selected by this window.

        Args:
            full: tensor that is updated in place.
            to_add: values added to the selected entries.
            dim: dimension to index along; defaults to ``self.dim``.

        Returns:
            ``full`` itself (a reference, not a copy).
        """
        if dim is None:
            dim = self.dim
        full[self._build_index(dim)] += to_add
        return full
|
||||||
|
|
||||||
|
|
||||||
|
class IndexListCallbacks:
    """Callback key names for the index-list context window hooks."""

    EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
    COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
    EXECUTE_START = "execute_start"
    EXECUTE_CLEANUP = "execute_cleanup"

    def init_callbacks(self):
        """Return a fresh, empty callbacks mapping."""
        return dict()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ContextSchedule:
    """A named context schedule: a label paired with the callable behind it."""

    # identifier of the schedule
    name: str
    # callable invoked to produce the schedule's context windows
    func: Callable
|
||||||
|
|
||||||
|
@dataclass
class ContextFuseMethod:
    """A named fuse method: a label paired with the callable behind it."""

    # identifier of the fuse method
    name: str
    # callable associated with this fuse method
    func: Callable
|
||||||
|
|
||||||
|
# Record of one evaluated context window: its index among the windows, the
# outputs computed for it, the resized conds that were fed in, and the window.
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||||
|
class IndexListContextHandler(ContextHandlerABC):
    """Context handler that evaluates the model on index-list windows and fuses the results.

    Windows are produced by ``context_schedule.func`` along dimension ``dim``
    of the input; each window is evaluated separately and the per-window
    outputs are combined according to ``fuse_method``.
    """

    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
        self.context_schedule = context_schedule
        self.fuse_method = fuse_method
        self.context_length = context_length
        self.context_overlap = context_overlap
        self.context_stride = context_stride
        self.closed_loop = closed_loop
        self.dim = dim
        # index of the current sampling step; updated by set_step()
        self._step = 0

        self.callbacks = {}

    def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
        """Return True when ``x_in`` is longer than ``context_length`` along ``self.dim``."""
        # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
        if x_in.size(self.dim) > self.context_length:
            logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
            return True
        return False

    def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
        """Walk the ``previous_controlnet`` chain of ``control`` and return ``control`` unchanged."""
        if control.previous_controlnet is not None:
            self.prepare_control_objects(control.previous_controlnet, device)
        return control

    def _window_model_cond(self, tensor: torch.Tensor, x_in: torch.Tensor, window: IndexListContextWindow, device=None):
        """Return the window's slice of a nested cond tensor, or None if it should be left as is.

        A tensor is sliced along ``self.dim`` when it has that dimension and its
        size there equals the full length of ``x_in``. A lower-rank tensor whose
        first dimension equals the full length is sliced along dim 0 instead.
        """
        full_length = x_in.size(self.dim)
        if self.dim < tensor.ndim and tensor.size(self.dim) == full_length:
            return window.get_tensor(tensor, device)
        # 0-dim tensors have no size(0), so they are excluded explicitly
        if 0 < tensor.ndim < self.dim and tensor.size(0) == full_length:
            # the tensor lacks self.dim, so index its first dimension instead
            return window.get_tensor(tensor, device, dim=0)
        return None

    def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: IndexListContextWindow, device=None) -> list:
        """Build a copy of ``cond_in`` whose tensors are restricted to ``window``.

        Args:
            cond_in: list of cond dicts, or None.
            x_in: the full (un-windowed) input; its size along ``self.dim`` is
                the length that tensors are compared against.
            window: window whose indices select the subset.
            device: device that resized tensors are moved to.

        Returns:
            A new list of shallow-copied cond dicts, or None when ``cond_in`` is None.
        """
        if cond_in is None:
            return None
        # reuse or resize cond items to match context requirements
        resized_cond = []
        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        for actual_cond in cond_in:
            resized_actual_cond = actual_cond.copy()
            # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
            for key in actual_cond:
                try:
                    cond_item = actual_cond[key]
                    if isinstance(cond_item, torch.Tensor):
                        # check that tensor is the expected length along self.dim
                        if self.dim < cond_item.ndim and cond_item.size(self.dim) == x_in.size(self.dim):
                            # if so, it's subsetting time - tell controls the expected indices so they can handle them
                            actual_cond_item = window.get_tensor(cond_item)
                            resized_actual_cond[key] = actual_cond_item.to(device)
                        else:
                            resized_actual_cond[key] = cond_item.to(device)
                    # look for control
                    elif key == "control":
                        resized_actual_cond[key] = self.prepare_control_objects(cond_item, device)
                    elif isinstance(cond_item, dict):
                        new_cond_item = cond_item.copy()
                        # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                        for cond_key, cond_value in new_cond_item.items():
                            if isinstance(cond_value, torch.Tensor):
                                windowed = self._window_model_cond(cond_value, x_in, window, device)
                                if windowed is not None:
                                    new_cond_item[cond_key] = windowed
                            # if has cond that is a Tensor, check if needs to be subset
                            elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
                                windowed = self._window_model_cond(cond_value.cond, x_in, window, device)
                                if windowed is not None:
                                    new_cond_item[cond_key] = cond_value._copy_with(windowed)
                            elif cond_key == "num_video_frames":  # for SVD
                                new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
                                new_cond_item[cond_key].cond = window.context_length
                        resized_actual_cond[key] = new_cond_item
                    else:
                        resized_actual_cond[key] = cond_item
                finally:
                    del cond_item  # just in case to prevent VRAM issues
            resized_cond.append(resized_actual_cond)
        return resized_cond

    def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
        """Set ``self._step`` to the index of the first sample sigma close to ``timestep``.

        Raises:
            Exception: when no entry of ``sample_sigmas`` matches ``timestep``.
        """
        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
        matches = torch.nonzero(mask)
        if torch.numel(matches) == 0:
            raise Exception("No sample_sigmas matched current timestep; something went wrong.")
        self._step = int(matches[0].item())

    def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
        """Ask the schedule for index lists and wrap each in an ``IndexListContextWindow``."""
        full_length = x_in.size(self.dim)  # TODO: choose dim based on model
        context_windows = self.context_schedule.func(full_length, self, model_options)
        context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
        return context_windows

    def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
        """Evaluate every context window and return the fused outputs, one tensor per cond."""
        self.set_step(timestep, model_options)
        context_windows = self.get_context_windows(model, x_in, model_options)
        enumerated_context_windows = list(enumerate(context_windows))

        # accumulators, one per cond
        conds_final = [torch.zeros_like(x_in) for _ in conds]
        if self.fuse_method.name == ContextFuseMethods.RELATIVE:
            counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
        else:
            counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
        biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]

        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
            callback(self, model, x_in, conds, timestep, model_options)

        # windows are evaluated one at a time and folded into the accumulators right away
        for enum_window in enumerated_context_windows:
            results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
            for result in results:
                self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
                                                    conds_final, counts_final, biases_final)
        # NOTE(review): cleanup callbacks only run once evaluation above has finished;
        # an exception raised during evaluation skips them - confirm this is intended.
        try:
            # finalize conds
            if self.fuse_method.name == ContextFuseMethods.RELATIVE:
                # relative is already normalized, so return as is
                del counts_final
                return conds_final
            else:
                # normalize conds via division by context usage counts
                for cond_final, count_final in zip(conds_final, counts_final):
                    cond_final.div_(count_final)
                del counts_final
                return conds_final
        finally:
            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
                callback(self, model, x_in, conds, timestep, model_options)

    def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
                                 model_options, device=None, first_device=None):
        """Run ``calc_cond_batch`` for each ``(window_idx, window)`` pair.

        Returns:
            A list of ``ContextResults``, one per evaluated window.
        """
        results: list[ContextResults] = []
        for window_idx, window in enumerated_context_windows:
            # allow processing to end between context window executions for faster Cancel
            comfy.model_management.throw_exception_if_processing_interrupted()

            for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
                callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)

            # update exposed params
            model_options["transformer_options"]["context_window"] = window
            # get subsections of x, timestep, conds
            sub_x = window.get_tensor(x_in, device)
            sub_timestep = window.get_tensor(timestep, device, dim=0)
            sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]

            sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
            if device is not None:
                # bring results back to the device of the full input
                for i, sub_cond_out in enumerate(sub_conds_out):
                    sub_conds_out[i] = sub_cond_out.to(x_in.device)
            results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
        return results

    def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_conds, window: IndexListContextWindow, window_idx: int, total_windows: int, timestep: torch.Tensor,
                                       conds_final: list[torch.Tensor], counts_final: list[torch.Tensor], biases_final: list[torch.Tensor]):
        """Fold one window's outputs into the running accumulators, in place.

        For the RELATIVE fuse method each index is blended as a running weighted
        average using ``biases_final``; otherwise weighted outputs and the
        weights themselves are added into ``conds_final`` / ``counts_final``.
        """
        if self.fuse_method.name == ContextFuseMethods.RELATIVE:
            first_idx = window.index_list[0]
            last_idx = window.index_list[-1]
            for pos, idx in enumerate(window.index_list):
                # bias is the influence of a specific index in relation to the whole context window
                bias = 1 - abs(idx - (first_idx + last_idx) / 2) / ((last_idx - first_idx + 1e-2) / 2)
                bias = max(1e-2, bias)
                # account for dims of tensors; tuples are used because indexing
                # with a non-tuple sequence of slices is deprecated in PyTorch
                idx_window = (slice(None),) * self.dim + (idx,)
                pos_window = (slice(None),) * self.dim + (pos,)
                # take weighted average relative to total bias of current idx
                for i in range(len(sub_conds_out)):
                    bias_total = biases_final[i][idx]
                    prev_weight = (bias_total / (bias_total + bias))
                    new_weight = (bias / (bias_total + bias))
                    # apply new values
                    conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
                    biases_final[i][idx] = bias_total + bias
        else:
            # add conds and counts based on weights of fuse method
            weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
            weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
            for i in range(len(sub_conds_out)):
                window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
                window.add_window(counts_final[i], weights_tensor)

        for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.COMBINE_CONTEXT_WINDOW_RESULTS, self.callbacks):
            callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||||
|
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||||
|
model_options = kwargs.get("model_options", None)
|
||||||
|
if model_options is None:
|
||||||
|
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||||
|
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||||
|
if handler is not None:
|
||||||
|
noise_shape = list(noise_shape)
|
||||||
|
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||||
|
return executor(model, noise_shape, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def create_prepare_sampling_wrapper(model: ModelPatcher):
    """Register ``_prepare_sampling_wrapper`` on ``model`` as a PREPARE_SAMPLING wrapper.

    The wrapper is stored under the key ``"ContextWindows_prepare_sampling"``.
    """
    wrapper_type = comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING
    wrapper_key = "ContextWindows_prepare_sampling"
    model.add_wrapper_with_key(wrapper_type, wrapper_key, _prepare_sampling_wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
    """Turn ``weights`` into a tensor that broadcasts against ``x_in`` along ``dim``.

    The result has as many dimensions as ``x_in``: size-1 axes everywhere except
    at position ``dim``, where the weights themselves live.
    """
    # one new leading axis per dimension before `dim`, one trailing axis per dimension after it
    leading_axes = (None,) * dim
    trailing_axes = (None,) * (len(x_in.shape) - dim - 1)
    flat_weights = torch.Tensor(weights).to(device=device)
    # indexing with None inserts a size-1 axis, same as unsqueeze
    return flat_weights[leading_axes + (Ellipsis,) + trailing_axes]
|
||||||
|
|
||||||
|
def get_shape_for_dim(x_in: torch.Tensor, dim: int) -> list[int]:
    """Return a broadcastable shape for ``x_in`` that keeps only the size of ``dim``.

    Every entry is 1 except the one at position ``dim``, which is ``x_in.shape[dim]``.
    """
    trailing_count = len(x_in.shape) - dim - 1
    return [1] * dim + [x_in.shape[dim]] + [1] * trailing_count
|
||||||
|
|
||||||
|
class ContextSchedules:
    # String identifiers for the available context-window schedules.
    # Each value is used as a key of CONTEXT_MAPPING, which pairs it with the
    # create_windows_* function that builds the windows for that schedule.
    UNIFORM_LOOPED = "looped_uniform"
    UNIFORM_STANDARD = "standard_uniform"
    STATIC_STANDARD = "standard_static"
    BATCHED = "batched"
|
||||||
|
|
||||||
|
|
||||||
|
# from https://github.com/neggles/animatediff-cli/blob/main/src/animatediff/pipelines/context.py
|
||||||
|
def create_windows_uniform_looped(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    """Build uniform context windows whose indexes may wrap past the last frame.

    Returns a list of windows, each a list of frame indexes taken modulo
    ``num_frames``. When ``num_frames`` is smaller than ``handler.context_length``
    a single window covering every frame is returned.
    """
    window_length = handler.context_length
    if num_frames < window_length:
        # all frames fit into one window
        return [list(range(num_frames))]

    windows = []
    context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / window_length))) + 1)
    # obtain uniform windows as normal, looping and all;
    # the step between frames inside a window doubles on every pass (1, 2, 4, ...)
    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(handler._step)
        pad = int(round(num_frames * fraction))
        first_start = int(fraction * context_step) + pad
        # unless the loop is closed, stop early by the overlap amount
        stop = num_frames + pad + (0 if handler.closed_loop else -handler.context_overlap)
        start_step = window_length * context_step - handler.context_overlap
        span = window_length * context_step
        for start in range(first_start, stop, start_step):
            windows.append([frame % num_frames for frame in range(start, start + span, context_step)])

    return windows
|
||||||
|
|
||||||
|
def create_windows_uniform_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    """Build uniform context windows that never wrap around the end of the frames.

    Windows are first generated the same way as the looped variant (indexes taken
    modulo ``num_frames``); any window that wraps is then shifted so that it ends
    on the last frame, and windows identical to an earlier one are removed.

    Returns a list of windows, each a list of frame indexes. When ``num_frames``
    is less than or equal to ``handler.context_length`` a single window covering
    every frame is returned. ``model_options`` is not used by this function.
    """
    # unlike looped, uniform_standard does NOT allow windows that loop back to the beginning;
    # instead, they get shifted to the corresponding end of the frames.
    # in the case that a window (shifted or not) is identical to the previous one, it gets skipped.
    windows = []
    if num_frames <= handler.context_length:
        windows.append(list(range(num_frames)))
        return windows

    # cap the configured stride count by one derived from the frames-to-window-length ratio
    context_stride = min(handler.context_stride, int(np.ceil(np.log2(num_frames / handler.context_length))) + 1)
    # first, obtain uniform windows as normal, looping and all
    for context_step in 1 << np.arange(context_stride):
        # offset of the window starts, derived from the handler's current step
        pad = int(round(num_frames * ordered_halving(handler._step)))
        for j in range(
            int(ordered_halving(handler._step) * context_step) + pad,
            num_frames + pad + (-handler.context_overlap),
            (handler.context_length * context_step - handler.context_overlap),
        ):
            windows.append([e % num_frames for e in range(j, j + handler.context_length * context_step, context_step)])

    # now that windows are created, shift any windows that loop, and delete duplicate windows
    delete_idxs = []
    win_i = 0
    # a while loop is used because windows may be inserted into the list during iteration
    while win_i < len(windows):
        # if window rolls over itself, need to shift it
        is_roll, roll_idx = does_window_roll_over(windows[win_i], num_frames)
        if is_roll:
            roll_val = windows[win_i][roll_idx]  # roll_val might not be 0 for windows of higher strides
            shift_window_to_end(windows[win_i], num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            if roll_val not in windows[(win_i+1) % len(windows)]:
                # need to insert new window here - just insert window starting at roll_val
                # NOTE(review): the inserted window is neither wrapped nor clamped to num_frames -
                # confirm roll_val + context_length cannot exceed num_frames here
                windows.insert(win_i+1, list(range(roll_val, roll_val + handler.context_length)))
        # delete window if it's not unique
        for pre_i in range(0, win_i):
            if windows[win_i] == windows[pre_i]:
                delete_idxs.append(win_i)
                break
        win_i += 1

    # reverse delete_idxs so that they will be deleted in an order that doesn't break idx correlation
    delete_idxs.reverse()
    for i in delete_idxs:
        windows.pop(i)

    return windows
|
||||||
|
|
||||||
|
|
||||||
|
def create_windows_static_standard(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    """Build a fixed set of equally sized, overlapping windows.

    Consecutive windows start ``context_length - context_overlap`` frames apart.
    The first window that would reach the end of the frames is moved back so
    that it ends exactly on the last frame, and no further windows are added.
    When ``num_frames`` fits in one window, a single window of all frames is returned.
    """
    window_length = handler.context_length
    if num_frames <= window_length:
        return [list(range(num_frames))]

    # always return the same set of windows
    windows = []
    stride = window_length - handler.context_overlap
    for start in range(0, num_frames, stride):
        if start + window_length < num_frames:
            windows.append(list(range(start, start + window_length)))
            continue
        # reached the end of frames: move the start back to keep the same context_length
        final_start = num_frames - window_length
        windows.append(list(range(final_start, final_start + window_length)))
        break
    return windows
|
||||||
|
|
||||||
|
|
||||||
|
def create_windows_batched(num_frames: int, handler: IndexListContextHandler, model_options: dict[str]):
    """Cut the frames into consecutive, non-overlapping windows of ``context_length``.

    The same windows are produced on every call. The last window is shorter
    when ``num_frames`` is not a multiple of ``handler.context_length``.
    """
    window_length = handler.context_length
    if num_frames <= window_length:
        return [list(range(num_frames))]
    # no overlap, just cut up based on context_length
    return [
        list(range(start, min(start + window_length, num_frames)))
        for start in range(0, num_frames, window_length)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def create_windows_default(num_frames: int, handler: IndexListContextHandler):
    """Return a single window that contains every frame index; ``handler`` is unused."""
    return [[*range(num_frames)]]
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each ContextSchedules identifier to the function that builds its windows.
# Every function listed here takes (num_frames, handler, model_options) and
# returns a list of windows, each window being a list of frame indexes.
CONTEXT_MAPPING = {
    ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
    ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
    ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
    ContextSchedules.BATCHED: create_windows_batched,
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
    """Look up ``context_schedule`` in CONTEXT_MAPPING and wrap it in a ContextSchedule.

    Raises:
        ValueError: if ``context_schedule`` is not a known schedule name.
    """
    if context_schedule not in CONTEXT_MAPPING:
        raise ValueError(f"Unknown context_schedule '{context_schedule}'.")
    return ContextSchedule(context_schedule, CONTEXT_MAPPING[context_schedule])
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
    """Compute window weights by delegating to the handler's fuse-method function.

    ``length`` is passed positionally; ``sigma``, ``handler``, ``full_length``
    and ``idxs`` are forwarded as keyword arguments.
    """
    fuse_func = handler.fuse_method.func
    return fuse_func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||||
|
|
||||||
|
|
||||||
|
def create_weights_flat(length: int, **kwargs) -> list[float]:
    """Return ``length`` weights of 1.0, i.e. every position counts the same."""
    return [1.0 for _ in range(length)]
|
||||||
|
|
||||||
|
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
    """Return weights that rise from 1 at the window edges to a peak in the middle.

    The weight of a position grows with its distance from the edge of the
    context window (weighted-average concept from the FreeNoise paper).
    An even ``length`` has two equal peak values, an odd one a single peak.
    """
    half, is_odd = divmod(length, 2)
    rising = list(range(1, half + 1))
    # an odd length gets one extra value at the centre
    centre = [half + 1] if is_odd else []
    return rising + centre + rising[::-1]
|
||||||
|
|
||||||
|
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
    """Return weights of 1 with linear ramps over the expected overlap regions.

    Based on code in Kijai's WanVideoWrapper:
    https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302

    Args:
        length: number of positions in the window.
        full_length: total number of frames the windows are taken from.
        idxs: frame indexes covered by this window.
        handler: object providing ``context_overlap``.

    Returns:
        A 1-D tensor of ``length`` weights. The left side ramps up unless the
        window contains frame 0; the right side ramps down unless the window
        contains the last frame. Only the expected overlap is given different weights.
    """
    weights_torch = torch.ones((length))
    # Clamp the ramp size to the window: an overlap larger than the window would
    # not fit into the slice, and an overlap of 0 must leave the weights untouched
    # (the slice [-0:] would otherwise select the WHOLE tensor).
    overlap = min(max(handler.context_overlap, 0), length)
    if overlap <= 0:
        return weights_torch
    # blend left-side on all except first window
    if min(idxs) > 0:
        weights_torch[:overlap] = torch.linspace(1e-37, 1, overlap)
    # blend right-side on all except last window
    if max(idxs) < full_length - 1:
        weights_torch[-overlap:] = torch.linspace(1, 1e-37, overlap)
    return weights_torch
|
||||||
|
|
||||||
|
class ContextFuseMethods:
    # String identifiers for the ways window results can be weighted when fused.
    # Each value is used as a key of FUSE_MAPPING.
    FLAT = "flat"
    PYRAMID = "pyramid"
    RELATIVE = "relative"
    OVERLAP_LINEAR = "overlap-linear"

    # Two groupings of the identifiers above: LIST leaves out RELATIVE, LIST_STATIC includes it.
    # NOTE(review): presumably these are the selectable options for non-static and
    # static schedules respectively - confirm against the callers.
    LIST = [PYRAMID, FLAT, OVERLAP_LINEAR]
    LIST_STATIC = [PYRAMID, RELATIVE, FLAT, OVERLAP_LINEAR]
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each ContextFuseMethods identifier to the function that creates its weights.
# Note that RELATIVE uses the same weight function as PYRAMID.
FUSE_MAPPING = {
    ContextFuseMethods.FLAT: create_weights_flat,
    ContextFuseMethods.PYRAMID: create_weights_pyramid,
    ContextFuseMethods.RELATIVE: create_weights_pyramid,
    ContextFuseMethods.OVERLAP_LINEAR: create_weights_overlap_linear,
}
|
||||||
|
|
||||||
|
def get_matching_fuse_method(fuse_method: str) -> ContextFuseMethod:
    """Look up ``fuse_method`` in FUSE_MAPPING and wrap it in a ContextFuseMethod.

    Raises:
        ValueError: if ``fuse_method`` is not a known fuse method name.
    """
    if fuse_method not in FUSE_MAPPING:
        raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
    return ContextFuseMethod(fuse_method, FUSE_MAPPING[fuse_method])
|
||||||
|
|
||||||
|
# Returns fraction that has denominator that is a power of 2
|
||||||
|
def ordered_halving(val):
    """Mirror the 64-bit binary form of ``val`` and return it as a fraction of 2**64.

    The integer is written as a zero-padded 64-digit binary string, the string
    is reversed (padding included), and the reversed digits are read back as an
    integer that is divided by 2**64. For example 1 -> 0.5, 2 -> 0.25, 3 -> 0.75.
    """
    mirrored_bits = format(val, "064b")[::-1]
    return int(mirrored_bits, 2) / 2**64
|
||||||
|
|
||||||
|
|
||||||
|
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return the frame indexes in ``range(num_frames)`` that no window contains.

    Args:
        windows: list of windows, each a list of frame indexes.
        num_frames: total number of frames.

    Returns:
        The uncovered indexes in ascending order. Window values outside
        ``range(num_frames)`` and repeated values are ignored.
    """
    # collect every covered index once; set membership avoids a linear
    # list.remove() (and a swallowed ValueError) for each window value
    covered = set()
    for window in windows:
        covered.update(window)
    return [idx for idx in range(num_frames) if idx not in covered]
|
||||||
|
|
||||||
|
|
||||||
|
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Report whether ``window`` wraps around the end of the frames.

    Values are compared modulo ``num_frames``. Returns ``(True, i)`` where ``i``
    is the position of the first value smaller than its predecessor, or
    ``(False, -1)`` when the values never decrease.
    """
    wrapped = [val % num_frames for val in window]
    # pair every value with its predecessor; the first value is compared against -1
    for pos, (before, current) in enumerate(zip([-1, *wrapped], wrapped)):
        if current < before:
            return True, pos
    return False, -1
|
||||||
|
|
||||||
|
|
||||||
|
def shift_window_to_start(window: list[int], num_frames: int):
    """Shift ``window`` in place so that its first value becomes 0.

    Every value has the original first value subtracted and is then taken
    modulo ``num_frames``, so wrapped windows stay inside the frame range.
    """
    offset = window[0]
    for pos, val in enumerate(window):
        # the modulo maps negative differences back into [0, num_frames)
        window[pos] = (val - offset) % num_frames
|
||||||
|
|
||||||
|
|
||||||
|
def shift_window_to_end(window: list[int], num_frames: int):
    """Shift ``window`` in place so that its last value becomes ``num_frames - 1``."""
    # 1) shift window to start
    shift_window_to_start(window, num_frames)
    # 2) slide every value forward by the gap between the last value and the final frame
    gap_to_end = num_frames - window[-1] - 1
    for pos, val in enumerate(window):
        window[pos] = val + gap_to_end
|
||||||
@ -38,6 +38,7 @@ from .ldm.hydit.controlnet import HunYuanControlNet
|
|||||||
from .t2i_adapter import adapter
|
from .t2i_adapter import adapter
|
||||||
from .model_base import convert_tensor
|
from .model_base import convert_tensor
|
||||||
from .model_management import cast_to_device
|
from .model_management import cast_to_device
|
||||||
|
from .ldm.qwen_image.controlnet import QwenImageControlNetModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .hooks import HookGroup
|
from .hooks import HookGroup
|
||||||
@ -240,11 +241,11 @@ class ControlNet(ControlBase):
|
|||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
compression_ratio = self.compression_ratio
|
compression_ratio = self.compression_ratio
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
compression_ratio *= self.vae.downscale_ratio
|
compression_ratio *= self.vae.spacial_compression_encode()
|
||||||
else:
|
else:
|
||||||
if self.latent_format is not None:
|
if self.latent_format is not None:
|
||||||
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
|
||||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[-1] * compression_ratio, x_noisy.shape[-2] * compression_ratio, self.upscale_algorithm, "center")
|
||||||
self.cond_hint = self.preprocess_image(self.cond_hint)
|
self.cond_hint = self.preprocess_image(self.cond_hint)
|
||||||
if self.vae is not None:
|
if self.vae is not None:
|
||||||
loaded_models = model_management.loaded_models(only_currently_used=True)
|
loaded_models = model_management.loaded_models(only_currently_used=True)
|
||||||
@ -657,6 +658,16 @@ def load_controlnet_flux_instantx(sd, model_options=None):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_qwen_instantx(sd, model_options=None):
    """Build a ControlNet wrapping a QwenImageControlNetModel loaded from ``sd``.

    Args:
        sd: state dict of the controlnet weights.
        model_options: optional dict of model options forwarded to
            ``controlnet_config``; ``None`` is treated as an empty dict.

    Returns:
        A ``ControlNet`` using the Wan21 latent format, a compression ratio of 1
        and no extra conds.
    """
    # a fresh dict per call avoids sharing one mutable default between calls
    if model_options is None:
        model_options = {}
    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
    control_model = QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
    control_model = controlnet_load_state_dict(control_model, sd)
    latent_format = comfy.latent_formats.Wan21()
    extra_conds = []
    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
    return control
|
||||||
|
|
||||||
|
|
||||||
def convert_mistoline(sd):
|
def convert_mistoline(sd):
|
||||||
return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
@ -732,8 +743,11 @@ def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_
|
|||||||
return load_controlnet_sd35(controlnet_data, model_options=model_options) # Stability sd3.5 format
|
return load_controlnet_sd35(controlnet_data, model_options=model_options) # Stability sd3.5 format
|
||||||
else:
|
else:
|
||||||
return load_controlnet_mmdit(controlnet_data, model_options=model_options) # SD3 diffusers controlnet
|
return load_controlnet_mmdit(controlnet_data, model_options=model_options) # SD3 diffusers controlnet
|
||||||
|
elif "transformer_blocks.0.img_mlp.net.0.proj.weight" in controlnet_data:
|
||||||
|
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
|
||||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|
||||||
|
|||||||
@ -225,19 +225,27 @@ class Flux(nn.Module):
|
|||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
|
index = 0
|
||||||
|
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
|
||||||
for ref in ref_latents:
|
for ref in ref_latents:
|
||||||
h_offset = 0
|
if index_ref_method:
|
||||||
w_offset = 0
|
index += 1
|
||||||
if ref.shape[-2] + h > ref.shape[-1] + w:
|
h_offset = 0
|
||||||
w_offset = w
|
w_offset = 0
|
||||||
else:
|
else:
|
||||||
h_offset = h
|
index = 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||||
|
w_offset = w
|
||||||
|
else:
|
||||||
|
h_offset = h
|
||||||
|
h = max(h, ref.shape[-2] + h_offset)
|
||||||
|
w = max(w, ref.shape[-1] + w_offset)
|
||||||
|
|
||||||
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
|
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
img = torch.cat([img, kontext], dim=1)
|
img = torch.cat([img, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
h = max(h, ref.shape[-2] + h_offset)
|
|
||||||
w = max(w, ref.shape[-1] + w_offset)
|
|
||||||
|
|
||||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||||
|
|||||||
@ -178,7 +178,7 @@ class FourierEmbedder(nn.Module):
|
|||||||
|
|
||||||
class CrossAttentionProcessor:
|
class CrossAttentionProcessor:
|
||||||
def __call__(self, attn, q, k, v):
|
def __call__(self, attn, q, k, v):
|
||||||
out = F.scaled_dot_product_attention(q, k, v)
|
out = comfy.ops.scaled_dot_product_attention(q, k, v)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -480,7 +480,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
mask = mask.unsqueeze(1)
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
if SDP_BATCH_LIMIT >= b:
|
if SDP_BATCH_LIMIT >= b:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
if not skip_output_reshape:
|
if not skip_output_reshape:
|
||||||
out = (
|
out = (
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
@ -493,7 +493,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
if mask.shape[0] > 1:
|
if mask.shape[0] > 1:
|
||||||
m = mask[i: i + SDP_BATCH_LIMIT]
|
m = mask[i: i + SDP_BATCH_LIMIT]
|
||||||
|
|
||||||
out[i: i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
out[i: i + SDP_BATCH_LIMIT] = comfy.ops.scaled_dot_product_attention(
|
||||||
q[i: i + SDP_BATCH_LIMIT],
|
q[i: i + SDP_BATCH_LIMIT],
|
||||||
k[i: i + SDP_BATCH_LIMIT],
|
k[i: i + SDP_BATCH_LIMIT],
|
||||||
v[i: i + SDP_BATCH_LIMIT],
|
v[i: i + SDP_BATCH_LIMIT],
|
||||||
|
|||||||
@ -295,7 +295,7 @@ def pytorch_attention(q, k, v):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||||
out = out.transpose(2, 3).reshape(orig_shape)
|
out = out.transpose(2, 3).reshape(orig_shape)
|
||||||
except model_management.OOM_EXCEPTION:
|
except model_management.OOM_EXCEPTION:
|
||||||
logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||||
|
|||||||
77
comfy/ldm/qwen_image/controlnet.py
Normal file
77
comfy/ldm/qwen_image/controlnet.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
import math
|
||||||
|
|
||||||
|
from .model import QwenImageTransformer2DModel
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageControlNetModel(QwenImageTransformer2DModel):
    """Qwen-Image transformer variant that outputs one control sample per block.

    The base transformer is built with ``final_layer=False``. On top of it the
    model adds one linear projection per transformer block and an embedder for
    the hint image.
    """

    def __init__(
        self,
        extra_condition_channels=0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
        # number of control samples returned by forward()
        # NOTE(review): presumably the block count of the main Qwen-Image model - confirm
        self.main_model_double = 60

        # controlnet_blocks: one projection for the output of every transformer block
        width = self.inner_dim
        self.controlnet_blocks = torch.nn.ModuleList(
            [operations.Linear(width, width, device=device, dtype=dtype) for _ in range(len(self.transformer_blocks))]
        )
        self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, width, device=device, dtype=dtype)

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        ref_latents=None,
        hint=None,
        transformer_options={},
        **kwargs
    ):
        """Run the transformer blocks on ``x`` plus ``hint`` and collect control samples.

        Returns a dict whose ``"input"`` entry is a tuple of at most
        ``self.main_model_double`` tensors; each block's projected output is
        repeated so that the blocks together cover that many entries.
        """
        patch = self.patch_size
        img_tokens, img_ids, _ = self.process_img(x)
        hint_tokens, _, _ = self.process_img(hint)

        # text position ids start after half of the larger patched image side
        w_len = (x.shape[-1] + (patch // 2)) // patch
        h_len = (x.shape[-2] + (patch // 2)) // patch
        txt_start = round(max(w_len // 2, h_len // 2))
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device)
        txt_ids = txt_ids.reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
        del ids, txt_ids, img_ids

        # the embedded hint is added to the embedded image tokens
        img_tokens = self.img_in(img_tokens) + self.controlnet_x_embedder(hint_tokens)
        txt_tokens = self.txt_in(self.txt_norm(context))

        if guidance is None:
            temb = self.time_text_embed(timesteps, img_tokens)
        else:
            temb = self.time_text_embed(timesteps, guidance * 1000, img_tokens)

        # how many times each block output is repeated to cover main_model_double entries
        repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))

        samples = []
        for i, block in enumerate(self.transformer_blocks):
            txt_tokens, img_tokens = block(
                hidden_states=img_tokens,
                encoder_hidden_states=txt_tokens,
                encoder_hidden_states_mask=attention_mask,
                temb=temb,
                image_rotary_emb=rotary_emb,
            )
            samples.extend([self.controlnet_blocks[i](img_tokens)] * repeat)

        return {"input": tuple(samples[:self.main_model_double])}
|
||||||
@ -294,13 +294,14 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
guidance_embeds: bool = False,
|
guidance_embeds: bool = False,
|
||||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||||
image_model=None,
|
image_model=None,
|
||||||
dtype=None,
|
final_layer=True, dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.patch_size = patch_size
|
self.patch_size = patch_size
|
||||||
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels or in_channels
|
self.out_channels = out_channels or in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
@ -330,25 +331,29 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
if final_layer:
|
||||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||||
self.gradient_checkpointing = False
|
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
def pos_embeds(self, x, context):
|
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
|
hidden_states = pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
||||||
|
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
|
||||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
|
||||||
|
|
||||||
txt_start = round(max(h_len, w_len))
|
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
||||||
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
|
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
||||||
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||||
|
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -357,19 +362,48 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
context,
|
context,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
guidance: torch.Tensor = None,
|
guidance: torch.Tensor = None,
|
||||||
|
ref_latents=None,
|
||||||
|
transformer_options={},
|
||||||
|
control=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
timestep = timesteps
|
timestep = timesteps
|
||||||
encoder_hidden_states = context
|
encoder_hidden_states = context
|
||||||
encoder_hidden_states_mask = attention_mask
|
encoder_hidden_states_mask = attention_mask
|
||||||
|
|
||||||
image_rotary_emb = self.pos_embeds(x, context)
|
hidden_states, img_ids, orig_shape = self.process_img(x)
|
||||||
|
num_embeds = hidden_states.shape[1]
|
||||||
|
|
||||||
hidden_states = pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
if ref_latents is not None:
|
||||||
orig_shape = hidden_states.shape
|
h = 0
|
||||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
w = 0
|
||||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
index = 0
|
||||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
|
||||||
|
for ref in ref_latents:
|
||||||
|
if index_ref_method:
|
||||||
|
index += 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
else:
|
||||||
|
index = 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
|
if ref.shape[-2] + h > ref.shape[-1] + w:
|
||||||
|
w_offset = w
|
||||||
|
else:
|
||||||
|
h_offset = h
|
||||||
|
h = max(h, ref.shape[-2] + h_offset)
|
||||||
|
w = max(w, ref.shape[-1] + w_offset)
|
||||||
|
|
||||||
|
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||||
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
|
||||||
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||||
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
hidden_states = self.img_in(hidden_states)
|
||||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
@ -384,18 +418,45 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
else self.time_text_embed(timestep, guidance, hidden_states)
|
||||||
)
|
)
|
||||||
|
|
||||||
for block in self.transformer_blocks:
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
encoder_hidden_states, hidden_states = block(
|
patches = transformer_options.get("patches", {})
|
||||||
hidden_states=hidden_states,
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
|
||||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
temb=temb,
|
if ("double_block", i) in blocks_replace:
|
||||||
image_rotary_emb=image_rotary_emb,
|
def block_wrap(args):
|
||||||
)
|
out = {}
|
||||||
|
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
|
||||||
|
hidden_states = out["img"]
|
||||||
|
encoder_hidden_states = out["txt"]
|
||||||
|
else:
|
||||||
|
encoder_hidden_states, hidden_states = block(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
|
temb=temb,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "double_block" in patches:
|
||||||
|
for p in patches["double_block"]:
|
||||||
|
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
|
||||||
|
hidden_states = out["img"]
|
||||||
|
encoder_hidden_states = out["txt"]
|
||||||
|
|
||||||
|
if control is not None: # Controlnet
|
||||||
|
control_i = control.get("input")
|
||||||
|
if i < len(control_i):
|
||||||
|
add = control_i[i]
|
||||||
|
if add is not None:
|
||||||
|
hidden_states += add
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
||||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||||
|
|||||||
@ -391,6 +391,7 @@ class WanModel(torch.nn.Module):
|
|||||||
cross_attn_norm=True,
|
cross_attn_norm=True,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
flf_pos_embed_token_number=None,
|
flf_pos_embed_token_number=None,
|
||||||
|
in_dim_ref_conv=None,
|
||||||
image_model=None,
|
image_model=None,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -484,6 +485,11 @@ class WanModel(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.img_emb = None
|
self.img_emb = None
|
||||||
|
|
||||||
|
if in_dim_ref_conv is not None:
|
||||||
|
self.ref_conv = operations.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:], device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
else:
|
||||||
|
self.ref_conv = None
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@ -526,6 +532,13 @@ class WanModel(torch.nn.Module):
|
|||||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
|
full_ref = None
|
||||||
|
if self.ref_conv is not None:
|
||||||
|
full_ref = kwargs.get("reference_latent", None)
|
||||||
|
if full_ref is not None:
|
||||||
|
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||||
|
x = torch.concat((full_ref, x), dim=1)
|
||||||
|
|
||||||
# context
|
# context
|
||||||
context = self.text_embedding(context)
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
@ -552,6 +565,9 @@ class WanModel(torch.nn.Module):
|
|||||||
# head
|
# head
|
||||||
x = self.head(x, e)
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
if full_ref is not None:
|
||||||
|
x = x[:, full_ref.shape[1]:]
|
||||||
|
|
||||||
# unpatchify
|
# unpatchify
|
||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
@ -570,6 +586,9 @@ class WanModel(torch.nn.Module):
|
|||||||
x = torch.cat([x, time_dim_concat], dim=2)
|
x = torch.cat([x, time_dim_concat], dim=2)
|
||||||
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
|
||||||
|
|
||||||
|
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||||
|
t_len += 1
|
||||||
|
|
||||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||||
@ -749,7 +768,12 @@ class CameraWanModel(WanModel):
|
|||||||
operations=None,
|
operations=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
if model_type == 'camera':
|
||||||
|
model_type = 'i2v'
|
||||||
|
else:
|
||||||
|
model_type = 't2v'
|
||||||
|
|
||||||
|
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
|
||||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
|
||||||
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
|
||||||
|
|||||||
@ -313,6 +313,7 @@ def model_lora_keys_unet(model, key_map=None):
|
|||||||
key_map["{}".format(key_lora)] = k
|
key_map["{}".format(key_lora)] = k
|
||||||
# Support transformer prefix format
|
# Support transformer prefix format
|
||||||
key_map["transformer.{}".format(key_lora)] = k
|
key_map["transformer.{}".format(key_lora)] = k
|
||||||
|
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|||||||
@ -928,6 +928,10 @@ class Flux(BaseModel):
|
|||||||
for lat in ref_latents:
|
for lat in ref_latents:
|
||||||
latents.append(self.process_latent_in(lat))
|
latents.append(self.process_latent_in(lat))
|
||||||
out['ref_latents'] = conds.CONDList(latents)
|
out['ref_latents'] = conds.CONDList(latents)
|
||||||
|
|
||||||
|
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||||
|
if ref_latents_method is not None:
|
||||||
|
out['ref_latents_method'] = conds.CONDConstant(ref_latents_method)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def extra_conds_shapes(self, **kwargs):
|
def extra_conds_shapes(self, **kwargs):
|
||||||
@ -1169,7 +1173,11 @@ class WAN21(BaseModel):
|
|||||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||||
|
|
||||||
return torch.cat((mask, image), dim=1)
|
concat_mask_index = kwargs.get("concat_mask_index", 0)
|
||||||
|
if concat_mask_index != 0:
|
||||||
|
return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1)
|
||||||
|
else:
|
||||||
|
return torch.cat((mask, image), dim=1)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@ -1184,6 +1192,10 @@ class WAN21(BaseModel):
|
|||||||
time_dim_concat = kwargs.get("time_dim_concat", None)
|
time_dim_concat = kwargs.get("time_dim_concat", None)
|
||||||
if time_dim_concat is not None:
|
if time_dim_concat is not None:
|
||||||
out['time_dim_concat'] = conds.CONDRegular(self.process_latent_in(time_dim_concat))
|
out['time_dim_concat'] = conds.CONDRegular(self.process_latent_in(time_dim_concat))
|
||||||
|
|
||||||
|
reference_latents = kwargs.get("reference_latents", None)
|
||||||
|
if reference_latents is not None:
|
||||||
|
out['reference_latent'] = conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@ -1365,10 +1377,28 @@ class Omnigen2(BaseModel):
|
|||||||
class QwenImage(BaseModel):
|
class QwenImage(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=QwenImageTransformer2DModel)
|
super().__init__(model_config, model_type, device=device, unet_model=QwenImageTransformer2DModel)
|
||||||
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
out['c_crossattn'] = conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = conds.CONDRegular(cross_attn)
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
latents = []
|
||||||
|
for lat in ref_latents:
|
||||||
|
latents.append(self.process_latent_in(lat))
|
||||||
|
out['ref_latents'] = conds.CONDList(latents)
|
||||||
|
|
||||||
|
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||||
|
if ref_latents_method is not None:
|
||||||
|
out['ref_latents_method'] = conds.CONDConstant(ref_latents_method)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def extra_conds_shapes(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -370,7 +370,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
|
||||||
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
|
||||||
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "camera"
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
|
dit_config["model_type"] = "camera"
|
||||||
|
else:
|
||||||
|
dit_config["model_type"] = "camera_2.2"
|
||||||
else:
|
else:
|
||||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["model_type"] = "i2v"
|
dit_config["model_type"] = "i2v"
|
||||||
@ -379,6 +382,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
|
||||||
if flf_weight is not None:
|
if flf_weight is not None:
|
||||||
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
|
||||||
|
|
||||||
|
ref_conv_weight = state_dict.get('{}ref_conv.weight'.format(key_prefix))
|
||||||
|
if ref_conv_weight is not None:
|
||||||
|
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||||
@ -490,6 +498,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "qwen_image"
|
dit_config["image_model"] = "qwen_image"
|
||||||
|
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||||
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
|
|||||||
@ -102,7 +102,6 @@ try:
|
|||||||
torch_version = torch.version.__version__
|
torch_version = torch.version.__version__
|
||||||
temp = torch_version.split(".")
|
temp = torch_version.split(".")
|
||||||
torch_version_numeric = (int(temp[0]), int(temp[1]))
|
torch_version_numeric = (int(temp[0]), int(temp[1]))
|
||||||
xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
|
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -126,11 +125,14 @@ if args.directml is not None:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, noqa: F401
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, noqa: F401
|
||||||
|
|
||||||
_ = torch.xpu.device_count()
|
|
||||||
xpu_available = xpu_available or torch.xpu.is_available()
|
|
||||||
except:
|
except:
|
||||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = torch.xpu.device_count()
|
||||||
|
xpu_available = torch.xpu.is_available()
|
||||||
|
except:
|
||||||
|
xpu_available = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
@ -369,9 +371,9 @@ try:
|
|||||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch_version_numeric >= (2, 8):
|
# if torch_version_numeric >= (2, 8):
|
||||||
if any((a in arch) for a in ["gfx1201"]):
|
# if any((a in arch) for a in ["gfx1201"]):
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
# ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||||
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
||||||
SUPPORT_FP8_OPS = True
|
SUPPORT_FP8_OPS = True
|
||||||
@ -386,7 +388,7 @@ if ENABLE_PYTORCH_ATTENTION:
|
|||||||
|
|
||||||
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
|
PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
|
||||||
try:
|
try:
|
||||||
if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
|
if (is_nvidia() or is_amd()) and PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
||||||
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
|
PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
|
||||||
logger.info("Enabled fp16 accumulation.")
|
logger.info("Enabled fp16 accumulation.")
|
||||||
@ -682,7 +684,13 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
|
|||||||
else:
|
else:
|
||||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||||
|
|
||||||
models = set(models)
|
models_temp = set()
|
||||||
|
for m in models:
|
||||||
|
models_temp.add(m)
|
||||||
|
for mm in m.model_patches_models():
|
||||||
|
models_temp.add(mm)
|
||||||
|
|
||||||
|
models = models_temp
|
||||||
|
|
||||||
models_to_load = []
|
models_to_load = []
|
||||||
models_freed = []
|
models_freed = []
|
||||||
@ -1063,10 +1071,12 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
|
|||||||
def device_supports_non_blocking(device):
|
def device_supports_non_blocking(device):
|
||||||
if torch.jit.is_tracing() or torch.jit.is_scripting():
|
if torch.jit.is_tracing() or torch.jit.is_scripting():
|
||||||
return True
|
return True
|
||||||
|
if args.force_non_blocking:
|
||||||
|
return True
|
||||||
if is_device_mps(device):
|
if is_device_mps(device):
|
||||||
return False # pytorch bug? mps doesn't support non blocking
|
return False # pytorch bug? mps doesn't support non blocking
|
||||||
if is_intel_xpu():
|
if is_intel_xpu(): #xpu does support non blocking but it is slower on iGPUs for some reason so disable by default until situation changes
|
||||||
return True
|
return False
|
||||||
if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
if args.deterministic: # TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
|
||||||
return False
|
return False
|
||||||
if directml_device:
|
if directml_device:
|
||||||
@ -1441,10 +1451,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
if torch_version_numeric < (2, 6):
|
if torch_version_numeric < (2, 3):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return torch.xpu.get_device_capability(device)['has_bfloat16_conversions']
|
return torch.xpu.is_bf16_supported()
|
||||||
|
|
||||||
if is_ascend_npu():
|
if is_ascend_npu():
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -485,6 +485,9 @@ class ModelPatcher(ModelManageable):
|
|||||||
def set_model_forward_timestep_embed_patch(self, patch):
|
def set_model_forward_timestep_embed_patch(self, patch):
|
||||||
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
self.set_model_patch(patch, "forward_timestep_embed_patch")
|
||||||
|
|
||||||
|
def set_model_double_block_patch(self, patch):
|
||||||
|
self.set_model_patch(patch, "double_block")
|
||||||
|
|
||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
@ -553,6 +556,30 @@ class ModelPatcher(ModelManageable):
|
|||||||
if hasattr(wrap_func, "to"):
|
if hasattr(wrap_func, "to"):
|
||||||
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
||||||
|
|
||||||
|
def model_patches_models(self):
|
||||||
|
to = self.model_options["transformer_options"]
|
||||||
|
models = []
|
||||||
|
if "patches" in to:
|
||||||
|
patches = to["patches"]
|
||||||
|
for name in patches:
|
||||||
|
patch_list = patches[name]
|
||||||
|
for i in range(len(patch_list)):
|
||||||
|
if hasattr(patch_list[i], "models"):
|
||||||
|
models += patch_list[i].models()
|
||||||
|
if "patches_replace" in to:
|
||||||
|
patches = to["patches_replace"]
|
||||||
|
for name in patches:
|
||||||
|
patch_list = patches[name]
|
||||||
|
for k in patch_list:
|
||||||
|
if hasattr(patch_list[k], "models"):
|
||||||
|
models += patch_list[k].models()
|
||||||
|
if "model_function_wrapper" in self.model_options:
|
||||||
|
wrap_func = self.model_options["model_function_wrapper"]
|
||||||
|
if hasattr(wrap_func, "models"):
|
||||||
|
models += wrap_func.models()
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
def model_dtype(self):
|
def model_dtype(self):
|
||||||
# this pokes into the internals of diffusion model a little bit
|
# this pokes into the internals of diffusion model a little bit
|
||||||
# todo: the base model isn't going to be aware that its diffusion model is patched this way
|
# todo: the base model isn't going to be aware that its diffusion model is patched this way
|
||||||
|
|||||||
27
comfy/ops.py
27
comfy/ops.py
@ -26,11 +26,36 @@ from .cli_args import args, PerformanceFeature
|
|||||||
from .execution_context import current_execution_context
|
from .execution_context import current_execution_context
|
||||||
from .float import stochastic_rounding
|
from .float import stochastic_rounding
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||||
|
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
|
import inspect
|
||||||
|
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||||
|
SDPA_BACKEND_PRIORITY = [
|
||||||
|
SDPBackend.FLASH_ATTENTION,
|
||||||
|
SDPBackend.EFFICIENT_ATTENTION,
|
||||||
|
SDPBackend.MATH,
|
||||||
|
]
|
||||||
|
|
||||||
|
SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||||
|
with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True):
|
||||||
|
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
logging.warning("Torch version too old to set sdpa backend priority.")
|
||||||
|
except (ModuleNotFoundError, TypeError):
|
||||||
|
logging.warning("Could not set sdpa backend priority.")
|
||||||
|
|
||||||
cast_to = model_management.cast_to # TODO: remove once no more references
|
cast_to = model_management.cast_to # TODO: remove once no more references
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from .model_management import cast_to
|
from .model_management import cast_to
|
||||||
import numbers
|
import numbers
|
||||||
|
import logging
|
||||||
|
|
||||||
RMSNorm = None
|
RMSNorm = None
|
||||||
|
|
||||||
@ -9,6 +10,7 @@ try:
|
|||||||
RMSNorm = torch.nn.RMSNorm
|
RMSNorm = torch.nn.RMSNorm
|
||||||
except:
|
except:
|
||||||
rms_norm_torch = None
|
rms_norm_torch = None
|
||||||
|
logging.warning("Please update pytorch to use native RMSNorm")
|
||||||
|
|
||||||
|
|
||||||
def rms_norm(x, weight=None, eps=1e-6):
|
def rms_norm(x, weight=None, eps=1e-6):
|
||||||
|
|||||||
@ -163,7 +163,7 @@ def cleanup_models(conds, models):
|
|||||||
cleanup_additional_models(set(control_cleanup))
|
cleanup_additional_models(set(control_cleanup))
|
||||||
|
|
||||||
|
|
||||||
def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
|
||||||
'''
|
'''
|
||||||
Registers hooks from conds.
|
Registers hooks from conds.
|
||||||
'''
|
'''
|
||||||
@ -172,8 +172,8 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
|
|||||||
for k in conds:
|
for k in conds:
|
||||||
get_hooks_from_cond(conds[k], hooks)
|
get_hooks_from_cond(conds[k], hooks)
|
||||||
# add wrappers and callbacks from ModelPatcher to transformer_options
|
# add wrappers and callbacks from ModelPatcher to transformer_options
|
||||||
model_options["transformer_options"]["wrappers"] = patcher_extension.copy_nested_dicts(model.wrappers)
|
patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("wrappers", {}), model.wrappers, copy_dict1=False)
|
||||||
model_options["transformer_options"]["callbacks"] = patcher_extension.copy_nested_dicts(model.callbacks)
|
patcher_extension.merge_nested_dicts(model_options["transformer_options"].setdefault("callbacks", {}), model.callbacks, copy_dict1=False)
|
||||||
# begin registering hooks
|
# begin registering hooks
|
||||||
registered = HookGroup()
|
registered = HookGroup()
|
||||||
target_dict = create_target_dict(EnumWeightTarget.Model)
|
target_dict = create_target_dict(EnumWeightTarget.Model)
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from .model_base import BaseModel
|
|||||||
from .model_management_types import ModelOptions
|
from .model_management_types import ModelOptions
|
||||||
from .model_patcher import ModelPatcher
|
from .model_patcher import ModelPatcher
|
||||||
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
|
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES
|
||||||
|
from .context_windows import ContextHandlerABC
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -32,6 +33,7 @@ def add_area_dims(area, num_dims):
|
|||||||
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
|
area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
|
||||||
return area
|
return area
|
||||||
|
|
||||||
|
|
||||||
def get_area_and_mult(conds, x_in, timestep_in):
|
def get_area_and_mult(conds, x_in, timestep_in):
|
||||||
dims = tuple(x_in.shape[2:])
|
dims = tuple(x_in.shape[2:])
|
||||||
area = None
|
area = None
|
||||||
@ -210,7 +212,14 @@ def finalize_default_conds(model: BaseModel, hooked_to_run: dict[HookGroup, list
|
|||||||
hooked_to_run[p.hooks] += [(p, i)]
|
hooked_to_run[p.hooks] += [(p, i)]
|
||||||
|
|
||||||
|
|
||||||
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
def calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options: dict[str]):
|
||||||
|
handler: ContextHandlerABC = model_options.get("context_handler", None)
|
||||||
|
if handler is None or not handler.should_use_context(model, conds, x_in, timestep, model_options):
|
||||||
|
return _calc_cond_batch_outer(model, conds, x_in, timestep, model_options)
|
||||||
|
return handler.execute(_calc_cond_batch_outer, model, conds, x_in, timestep, model_options)
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options):
|
||||||
executor = patcher_extension.WrapperExecutor.new_executor(
|
executor = patcher_extension.WrapperExecutor.new_executor(
|
||||||
_calc_cond_batch,
|
_calc_cond_batch,
|
||||||
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.CALC_COND_BATCH, model_options, is_model_options=True)
|
||||||
@ -754,6 +763,7 @@ class Sampler:
|
|||||||
sigma = float(sigmas[0])
|
sigma = float(sigmas[0])
|
||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
self.sampler_function = sampler_function
|
self.sampler_function = sampler_function
|
||||||
|
|||||||
@ -229,17 +229,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
|
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
|
||||||
index = 0
|
index = 0
|
||||||
pad_extra = 0
|
pad_extra = 0
|
||||||
|
embeds_info = []
|
||||||
for o in other_embeds:
|
for o in other_embeds:
|
||||||
emb = o[1]
|
emb = o[1]
|
||||||
if torch.is_tensor(emb):
|
if torch.is_tensor(emb):
|
||||||
emb = {"type": "embedding", "data": emb}
|
emb = {"type": "embedding", "data": emb}
|
||||||
|
|
||||||
|
extra = None
|
||||||
emb_type = emb.get("type", None)
|
emb_type = emb.get("type", None)
|
||||||
if emb_type == "embedding":
|
if emb_type == "embedding":
|
||||||
emb = emb.get("data", None)
|
emb = emb.get("data", None)
|
||||||
else:
|
else:
|
||||||
if hasattr(self.transformer, "preprocess_embed"):
|
if hasattr(self.transformer, "preprocess_embed"):
|
||||||
emb = self.transformer.preprocess_embed(emb, device=device)
|
emb, extra = self.transformer.preprocess_embed(emb, device=device)
|
||||||
else:
|
else:
|
||||||
emb = None
|
emb = None
|
||||||
|
|
||||||
@ -254,6 +256,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
|
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
|
||||||
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
|
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
|
||||||
index += emb_shape - 1
|
index += emb_shape - 1
|
||||||
|
embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
|
||||||
else:
|
else:
|
||||||
index += -1
|
index += -1
|
||||||
pad_extra += emb_shape
|
pad_extra += emb_shape
|
||||||
@ -268,11 +271,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
attention_masks.append(attention_mask)
|
attention_masks.append(attention_mask)
|
||||||
num_tokens.append(sum(attention_mask))
|
num_tokens.append(sum(attention_mask))
|
||||||
|
|
||||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
|
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
|
||||||
|
|
||||||
def forward(self, tokens):
|
def forward(self, tokens):
|
||||||
device = self.transformer.get_input_embeddings().weight.device
|
device = self.transformer.get_input_embeddings().weight.device
|
||||||
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
|
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
|
||||||
|
|
||||||
attention_mask_model = None
|
attention_mask_model = None
|
||||||
if self.enable_attention_masks:
|
if self.enable_attention_masks:
|
||||||
@ -283,7 +286,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
else:
|
else:
|
||||||
intermediate_output = self.layer_idx
|
intermediate_output = self.layer_idx
|
||||||
|
|
||||||
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
|
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
|
||||||
|
|
||||||
if self.layer == "last":
|
if self.layer == "last":
|
||||||
z = outputs[0].float()
|
z = outputs[0].float()
|
||||||
@ -644,7 +647,10 @@ class SDTokenizer:
|
|||||||
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
|
||||||
|
|
||||||
text = escape_important(text)
|
text = escape_important(text)
|
||||||
parsed_weights = token_weights(text, 1.0)
|
if kwargs.get("disable_weights", False):
|
||||||
|
parsed_weights = [(text, 1.0)]
|
||||||
|
else:
|
||||||
|
parsed_weights = token_weights(text, 1.0)
|
||||||
vocab = self.tokenizer.get_vocab()
|
vocab = self.tokenizer.get_vocab()
|
||||||
|
|
||||||
# tokenize words
|
# tokenize words
|
||||||
|
|||||||
@ -1129,6 +1129,18 @@ class WAN21_Camera(WAN21_T2V):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class WAN22_Camera(WAN21_T2V):
    """Supported-model entry for the "camera_2.2" WAN variant.

    Only ``unet_config`` and ``get_model`` are overridden; everything else
    is inherited from ``WAN21_T2V``.
    """

    # Configuration values identifying this variant.
    unet_config = {
        "image_model": "wan2.1",
        "model_type": "camera_2.2",
        "in_dim": 36,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Build the base model for this entry.

        ``state_dict`` and ``prefix`` are accepted for interface
        compatibility but are not read here.
        """
        return model_base.WAN21_Camera(self, image_to_video=False, device=device)
|
||||||
|
|
||||||
|
|
||||||
class WAN21_Vace(WAN21_T2V):
|
class WAN21_Vace(WAN21_T2V):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -1327,6 +1339,7 @@ class Omnigen2(supported_models_base.BASE):
|
|||||||
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(omnigen2.Omnigen2Tokenizer, omnigen2.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(omnigen2.Omnigen2Tokenizer, omnigen2.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
class QwenImage(supported_models_base.BASE):
|
class QwenImage(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "qwen_image",
|
"image_model": "qwen_image",
|
||||||
@ -1337,7 +1350,7 @@ class QwenImage(supported_models_base.BASE):
|
|||||||
"shift": 1.15,
|
"shift": 1.15,
|
||||||
}
|
}
|
||||||
|
|
||||||
memory_usage_factor = 1.8 #TODO
|
memory_usage_factor = 1.8 # TODO
|
||||||
|
|
||||||
unet_extra_config = {}
|
unet_extra_config = {}
|
||||||
latent_format = latent_formats.Wan21
|
latent_format = latent_formats.Wan21
|
||||||
@ -1357,6 +1370,6 @@ class QwenImage(supported_models_base.BASE):
|
|||||||
return supported_models_base.ClipTarget(qwen_image.QwenImageTokenizer, qwen_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(qwen_image.QwenImageTokenizer, qwen_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -118,7 +118,7 @@ class BertModel_(torch.nn.Module):
|
|||||||
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
|
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
|
||||||
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
|
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
|
||||||
|
|
||||||
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||||
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
|
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
|
||||||
mask = None
|
mask = None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
import torch
|
import math
|
||||||
import torch.nn as nn
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Any
|
from typing import Optional, Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from . import qwen_vl
|
||||||
from ..ldm.common_dit import rms_norm
|
from ..ldm.common_dit import rms_norm
|
||||||
from ..ldm.modules.attention import optimized_attention_for_device
|
from ..ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
@ -23,6 +26,7 @@ class Llama2Config:
|
|||||||
rms_norm_add = False
|
rms_norm_add = False
|
||||||
mlp_activation = "silu"
|
mlp_activation = "silu"
|
||||||
qkv_bias = False
|
qkv_bias = False
|
||||||
|
rope_dims = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -41,6 +45,7 @@ class Qwen25_3BConfig:
|
|||||||
rms_norm_add = False
|
rms_norm_add = False
|
||||||
mlp_activation = "silu"
|
mlp_activation = "silu"
|
||||||
qkv_bias = True
|
qkv_bias = True
|
||||||
|
rope_dims = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -59,6 +64,7 @@ class Qwen25_7BVLI_Config:
|
|||||||
rms_norm_add = False
|
rms_norm_add = False
|
||||||
mlp_activation = "silu"
|
mlp_activation = "silu"
|
||||||
qkv_bias = True
|
qkv_bias = True
|
||||||
|
rope_dims = [16, 24, 24]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -77,6 +83,7 @@ class Gemma2_2B_Config:
|
|||||||
rms_norm_add = True
|
rms_norm_add = True
|
||||||
mlp_activation = "gelu_pytorch_tanh"
|
mlp_activation = "gelu_pytorch_tanh"
|
||||||
qkv_bias = False
|
qkv_bias = False
|
||||||
|
rope_dims = None
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
@ -101,24 +108,30 @@ def rotate_half(x):
|
|||||||
return torch.cat((-x2, x1), dim=-1)
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
    """Build the rotary-embedding cos/sin tables for the given positions.

    Args:
        head_dim: size of one attention head; one frequency is generated per
            pair of channels.
        position_ids: 2-D tensor of positions. With ``rope_dims`` set and more
            than one row, the rows are treated as separate position axes.
        theta: base of the frequency progression.
        rope_dims: optional list of section widths. The list is doubled and
            used to split the last dimension; section ``i`` takes its values
            from position row ``i % 3``.
        device: device for the frequency table.

    Returns:
        ``(cos, sin)``. In the multi-axis case each has a leading singleton
        dimension added at index 0; otherwise a singleton dimension is
        inserted at index 1 so the tables broadcast over attention heads.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)

    num_rows = position_ids.shape[0]
    inv_freq_col = inv_freq[None, :, None].float().expand(num_rows, -1, 1)
    positions_row = position_ids[:, None, :].float()
    # (rows, freqs, 1) @ (rows, 1, seq) -> (rows, freqs, seq) -> (rows, seq, freqs)
    angles = (inv_freq_col.float() @ positions_row.float()).transpose(1, 2)
    angles = torch.cat((angles, angles), dim=-1)
    cos_table = angles.cos()
    sin_table = angles.sin()

    if rope_dims is None or num_rows <= 1:
        return (cos_table.unsqueeze(1), sin_table.unsqueeze(1))

    sections = rope_dims * 2

    def _pick_axes(table):
        # Each channel section reads from one position row, cycling over three rows.
        chunks = table.split(sections, dim=-1)
        picked = [chunk[idx % 3] for idx, chunk in enumerate(chunks)]
        return torch.cat(picked, dim=-1).unsqueeze(0)

    cos_out = _pick_axes(cos_table)
    sin_out = _pick_axes(sin_table)
    return (cos_out, sin_out)
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(xq, xk, freqs_cis):
    """Apply rotary position embedding to a query/key pair.

    ``freqs_cis`` is indexed as ``(cos, sin)``; both tables are used as-is
    and must already broadcast against ``xq`` and ``xk``.
    """
    cos_table = freqs_cis[0]
    sin_table = freqs_cis[1]

    def _rotate(t):
        return (t * cos_table) + (rotate_half(t) * sin_table)

    q_rotated = _rotate(xq)
    k_rotated = _rotate(xk)
    return q_rotated, k_rotated
|
||||||
@ -282,7 +295,7 @@ class Llama2_(nn.Module):
|
|||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||||
|
|
||||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||||
if embeds is not None:
|
if embeds is not None:
|
||||||
x = embeds
|
x = embeds
|
||||||
else:
|
else:
|
||||||
@ -291,9 +304,13 @@ class Llama2_(nn.Module):
|
|||||||
if self.normalize_in:
|
if self.normalize_in:
|
||||||
x *= self.config.hidden_size ** 0.5
|
x *= self.config.hidden_size ** 0.5
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
|
||||||
|
|
||||||
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
freqs_cis = precompute_freqs_cis(self.config.head_dim,
|
||||||
x.shape[1],
|
position_ids,
|
||||||
self.config.rope_theta,
|
self.config.rope_theta,
|
||||||
|
self.config.rope_dims,
|
||||||
device=x.device)
|
device=x.device)
|
||||||
|
|
||||||
mask = None
|
mask = None
|
||||||
@ -382,8 +399,37 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
|||||||
self.num_layers = config.num_hidden_layers
|
self.num_layers = config.num_hidden_layers
|
||||||
|
|
||||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||||
|
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def preprocess_embed(self, embed, device):
    """Turn a non-text embed description into model embeddings.

    Only ``embed["type"] == "image"`` is handled: ``embed["data"]`` is
    patchified, moved to ``device`` as float32 and passed through
    ``self.visual``.

    Returns:
        ``(embeddings, grid)`` for images, ``(None, None)`` for any other type.
    """
    if embed["type"] != "image":
        return None, None

    patches, grid_thw = qwen_vl.process_qwen2vl_images(embed["data"])
    patches = patches.to(device, dtype=torch.float32)
    return self.visual(patches, grid_thw), grid_thw
|
||||||
|
|
||||||
|
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
    """Run the parent forward pass, supplying 3-row position ids when an image embed is present.

    ``embeds_info`` is iterated as a list of dicts read through the keys
    "type", "extra", "index" and "size". For an entry of type "image" a
    ``(3, sequence_length)`` position table is built; ``embeds`` must not be
    None in that case because its shape and device are read. When no image
    entry supplies a grid, ``position_ids=None`` is passed on.

    NOTE(review): the table is re-created for every image entry, so only the
    last image entry is reflected in the result — confirm that multiple
    images are not expected here.
    """
    grid = None
    for e in embeds_info:
        if e.get("type") == "image":
            # "extra" carries the grid tensor for this embed; presumably the
            # (t, h, w) grid returned by preprocess_embed — TODO confirm.
            grid = e.get("extra", None)
            position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
            start = e.get("index")
            # Tokens before the image: plain 0..start-1 on all three rows.
            position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
            end = e.get("size") + start
            # Tokens after the image resume at start + (largest grid value // 2).
            len_max = int(grid.max()) // 2
            start_next = len_max + start
            position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
            # Row 0: constant `start` across the whole image span.
            position_ids[0, start:end] = start
            # Row 1: each value of start..start+max_d-1 repeated consecutively,
            # truncated to the span length (max_d derived from grid[0][1]).
            max_d = int(grid[0][1]) // 2
            position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
            # Row 2: the sequence start..start+max_d-1 cycled, truncated to the
            # span length (max_d derived from grid[0][2]).
            max_d = int(grid[0][2]) // 2
            position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]

    # No image grid found: let the parent fall back to its default positions.
    if grid is None:
        position_ids = None

    return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
|
||||||
|
|
||||||
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||||
def __init__(self, config_dict, dtype, device, operations):
|
def __init__(self, config_dict, dtype, device, operations):
|
||||||
|
|||||||
@ -21,13 +21,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
tokenizer_data = {}
|
tokenizer_data = {}
|
||||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
|
||||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
    """Tokenize ``text`` inside a chat template and splice images into the tokens.

    When ``llama_template`` is None, ``self.llama_template_images`` is used
    if any images are given and ``self.llama_template`` otherwise. The parent
    tokenizer is called with ``disable_weights=True``. In the first entry of
    the returned mapping, each token whose id is 151655 is replaced, in
    order, by an image embed dict until ``images`` is exhausted; the rest of
    the token tuple is kept.

    Returns:
        The token mapping from the parent tokenizer, modified in place.
    """
    if llama_template is not None:
        template = llama_template
    elif len(images) > 0:
        template = self.llama_template_images
    else:
        template = self.llama_template
    llama_text = template.format(text)

    tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
    first_key = next(iter(tokens))

    used_images = 0
    for row in tokens[first_key]:
        for pos, entry in enumerate(row):
            if entry[0] != 151655:
                continue
            if len(images) > used_images:
                image_embed = {"type": "image", "data": images[used_images], "original_type": "image"}
                row[pos] = (image_embed,) + entry[1:]
                used_images += 1
    return tokens
|
||||||
|
|
||||||
|
|
||||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||||
|
|||||||
428
comfy/text_encoders/qwen_vl.py
Normal file
428
comfy/text_encoders/qwen_vl.py
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import math
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
|
|
||||||
|
def process_qwen2vl_images(
    images: torch.Tensor,
    min_pixels: int = 3136,
    max_pixels: int = 12845056,
    patch_size: int = 14,
    temporal_patch_size: int = 2,
    merge_size: int = 2,
    image_mean: list = None,
    image_std: list = None,
):
    """Resize, normalize and patchify an image for the vision tower.

    Only the first image of the batch (``images[0]``) is processed; any
    further batch entries are ignored.

    Args:
        images: 4-D tensor unpacked as (batch, height, width, channels).
        min_pixels: lower bound on the resized area (height * width).
        max_pixels: upper bound on the resized area.
        patch_size: spatial edge length of one patch.
        temporal_patch_size: number of times the single frame is repeated to
            fill the temporal extent of a patch.
        merge_size: number of patches merged per spatial axis; the resized
            height and width are multiples of ``patch_size * merge_size``.
        image_mean: per-channel mean (three values); a default is used when None.
        image_std: per-channel std (three values); a default is used when None.

    Returns:
        ``(flatten_patches, image_grid_thw)`` where ``flatten_patches`` has
        shape ``(grid_h * grid_w, channels * temporal_patch_size * patch_size ** 2)``
        and ``image_grid_thw`` is a long tensor ``[[1, grid_h, grid_w]]``.
    """
    if image_mean is None:
        image_mean = [0.48145466, 0.4578275, 0.40821073]
    if image_std is None:
        image_std = [0.26862954, 0.26130258, 0.27577711]

    _, height, width, _ = images.shape
    device = images.device

    # Channels-first layout for interpolation; only the first image is used.
    img = images.permute(0, 3, 1, 2)[0]

    # Snap both edges to a multiple of the merged-patch size, then rescale
    # if the resulting area falls outside [min_pixels, max_pixels].
    factor = patch_size * merge_size
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    img_resized = F.interpolate(
        img.unsqueeze(0),
        size=(h_bar, w_bar),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)

    normalized = img_resized.clone()
    for c in range(3):
        normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]

    grid_t = 1
    grid_h = h_bar // patch_size
    grid_w = w_bar // patch_size
    image_grid_thw = torch.tensor([[grid_t, grid_h, grid_w]], device=device, dtype=torch.long)

    channel = normalized.shape[0]
    # Repeat the still frame to fill one temporal patch. This was hard-coded
    # to 2, which broke the reshape for any other temporal_patch_size.
    pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)

    patches = pixel_values.reshape(
        grid_t,
        temporal_patch_size,
        channel,
        grid_h // merge_size,
        merge_size,
        patch_size,
        grid_w // merge_size,
        merge_size,
        patch_size,
    )

    # Group patches by merge block so the merge_size x merge_size patches
    # that get merged together end up adjacent in the flattened sequence.
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        channel * temporal_patch_size * patch_size * patch_size
    )

    return flatten_patches, image_grid_thw
|
||||||
|
|
||||||
|
|
||||||
|
class VisionPatchEmbed(nn.Module):
    """Project flattened image patches to embeddings with a strided 3-D convolution."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 3584,
        device=None,
        dtype=None,
        ops=None,
    ):
        """Create the bias-free Conv3d whose kernel and stride both equal one patch.

        ``ops`` is the namespace providing the ``Conv3d`` class to use.
        """
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # Kernel == stride, so every patch is consumed exactly once.
        patch_shape = [temporal_patch_size, patch_size, patch_size]
        self.proj = ops.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=patch_shape,
            stride=patch_shape,
            bias=False,
            device=device,
            dtype=dtype,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Reshape the input into individual patches and return one ``embed_dim`` row per patch."""
        patches = hidden_states.view(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        projected = self.proj(patches)
        return projected.view(-1, self.embed_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
    """Split the last dimension in two and return ``cat(-second_half, first_half)``."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb_vision(q, k, cos, sin):
    """Apply rotary embedding to ``q`` and ``k``.

    ``cos`` and ``sin`` get a singleton dimension inserted before their last
    axis and are cast to float32 before use.
    """
    cos = cos.unsqueeze(-2).float()
    sin = sin.unsqueeze(-2).float()
    rotated = [(t * cos) + (rotate_half(t) * sin) for t in (q, k)]
    return rotated[0], rotated[1]
|
||||||
|
|
||||||
|
|
||||||
|
class VisionRotaryEmbedding(nn.Module):
    """Generate rotary angle tables (position times inverse frequency)."""

    def __init__(self, dim: int, theta: float = 10000.0):
        """Store the rotary dimension and the frequency base; no parameters are created."""
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, seqlen: int, device) -> torch.Tensor:
        """Return a ``(seqlen, ceil(dim / 2))`` table of angles for positions ``0..seqlen-1``."""
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        inv_freq = 1.0 / (self.theta ** exponents)
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMerger(nn.Module):
|
||||||
|
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = context_dim * (spatial_merge_size ** 2)
|
||||||
|
self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
|
||||||
|
nn.GELU(),
|
||||||
|
ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = self.ln_q(x).reshape(-1, self.hidden_size)
|
||||||
|
x = self.mlp(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VisionAttention(nn.Module):
    """Multi-head self-attention over a packed patch sequence, run per segment."""

    def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
        """Create the fused qkv projection and the output projection.

        ``ops`` is the namespace providing the ``Linear`` class to use.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim ** -0.5

        self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
        self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        """Attend within each segment of the sequence and project the result.

        Args:
            hidden_states: ``(seq, hidden)`` or ``(batch, seq, hidden)``. The
                batch axis is folded away, so only batch size 1 is meaningful.
            position_embeddings: optional ``(cos, sin)`` rotary tables.
            cu_seqlens: cumulative segment boundaries as a tensor; attention
                is computed independently inside each segment. None means a
                single segment spanning the whole sequence.
            optimized_attention: callable invoked as
                ``fn(q, k, v, num_heads, skip_reshape=True)``.

        Returns:
            Tensor of shape ``(seq, hidden)``.
        """
        if hidden_states.dim() == 2:
            seq_length, _ = hidden_states.shape
            batch_size = 1
            hidden_states = hidden_states.unsqueeze(0)
        else:
            batch_size, seq_length, _ = hidden_states.shape

        qkv = self.qkv(hidden_states)
        qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
        query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        # (seq, heads, head_dim) -> (1, heads, seq, head_dim)
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        # The signature defaults cu_seqlens to None, so honour it instead of
        # failing on the slice below: treat the sequence as one segment.
        if cu_seqlens is None:
            segment_lengths = [seq_length]
        else:
            segment_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

        splits = [
            torch.split(tensor, segment_lengths, dim=2) for tensor in (query_states, key_states, value_states)
        ]

        attn_outputs = [
            optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
            for q, k, v in zip(*splits)
        ]
        attn_output = torch.cat(attn_outputs, dim=1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)

        return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class VisionMLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
        """Create the three biased projections; ``ops`` provides the ``Linear`` class."""
        super().__init__()
        linear_kwargs = dict(bias=True, device=device, dtype=dtype)
        self.gate_proj = ops.Linear(hidden_size, intermediate_size, **linear_kwargs)
        self.up_proj = ops.Linear(hidden_size, intermediate_size, **linear_kwargs)
        self.down_proj = ops.Linear(intermediate_size, hidden_size, **linear_kwargs)
        self.act_fn = nn.SiLU()

    def forward(self, hidden_state):
        """Apply the gated projection and map back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(hidden_state))
        gated = gate * self.up_proj(hidden_state)
        return self.down_proj(gated)
|
||||||
|
|
||||||
|
|
||||||
|
class VisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection."""

    def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
        """Create both norms, the attention module and the MLP; ``ops`` provides the layer classes."""
        super().__init__()
        self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
        self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
        self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        cu_seqlens=None,
        optimized_attention=None,
    ) -> torch.Tensor:
        """Return ``hidden_states`` after the attention and MLP residual updates.

        ``position_embeddings``, ``cu_seqlens`` and ``optimized_attention``
        are forwarded unchanged to the attention module.
        """
        attn_out = self.attn(self.norm1(hidden_states), position_embeddings, cu_seqlens, optimized_attention)
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.norm2(hidden_states))
        return hidden_states + mlp_out
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2VLVisionTransformer(nn.Module):
    """Vision transformer that alternates windowed and full attention.

    Patches are embedded, regrouped so that the tokens of each attention
    window are contiguous, run through ``num_layers`` ``VisionBlock`` layers
    and finally reduced by a ``PatchMerger``.

    Layers whose index is listed in ``fullatt_block_indexes`` attend over a
    whole frame (``cu_seqlens``); every other layer attends only inside a
    window (``cu_window_seqlens``).
    """

    def __init__(
        self,
        hidden_size: int = 3584,
        output_hidden_size: int = 3584,
        intermediate_size: int = 3420,
        num_heads: int = 16,
        num_layers: int = 32,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        spatial_merge_size: int = 2,
        window_size: int = 112,
        device=None,
        dtype=None,
        ops=None,
        fullatt_block_indexes=(7, 15, 23, 31),
    ):
        """Build the patch embedding, rotary embedding, blocks and merger.

        Args:
            hidden_size: Width of the patch embeddings and of every block.
            output_hidden_size: ``dim`` passed to the ``PatchMerger``.
            intermediate_size: Passed to each ``VisionBlock``.
            num_heads: Passed to each ``VisionBlock``; also fixes the rotary
                head dimension (``hidden_size // num_heads``).
            num_layers: Number of ``VisionBlock`` layers.
            patch_size: Spatial patch size, forwarded to the patch embedding
                and used to size attention windows.
            temporal_patch_size: Forwarded to the patch embedding.
            spatial_merge_size: Side length of the square group of patches
                that is kept contiguous and later merged.
            window_size: Attention window size; divided by
                ``spatial_merge_size`` and ``patch_size`` to get the window
                side in merged-patch units.
            device, dtype, ops: Forwarded unchanged to the sub-modules.
            fullatt_block_indexes: Indices of the layers that use full
                (per-frame) attention instead of windowed attention.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.spatial_merge_size = spatial_merge_size
        self.window_size = window_size
        self.fullatt_block_indexes = list(fullatt_block_indexes)

        self.patch_embed = VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=3,
            embed_dim=hidden_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )

        head_dim = hidden_size // num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
            for _ in range(num_layers)
        ])

        self.merger = PatchMerger(
            dim=output_hidden_size,
            context_dim=hidden_size,
            spatial_merge_size=spatial_merge_size,
            device=device,
            dtype=dtype,
            ops=ops,
        )

    def get_window_index(self, grid_thw):
        """Compute the window-major ordering of merged patch groups.

        Args:
            grid_thw: Iterable of ``(t, h, w)`` triples, one per image/video,
                giving the patch grid size. Entries may be 0-d tensors or
                plain ints.

        Returns:
            A tuple ``(window_index, cu_window_seqlens)``:

            * ``window_index`` - 1-D integer tensor listing merged-group
              indices in window-major order across all inputs.
            * ``cu_window_seqlens`` - Python list of cumulative window
              lengths measured in un-merged patches, starting with ``0``.
              It may contain repeated values (see the padding note below).
        """
        window_index = []
        cu_window_seqlens = [0]
        window_index_id = 0
        merge_unit = self.spatial_merge_size * self.spatial_merge_size
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            # Normalise to Python ints so tensors and plain ints behave alike.
            grid_t, grid_h, grid_w = int(grid_t), int(grid_h), int(grid_w)
            llm_grid_h = grid_h // self.spatial_merge_size
            llm_grid_w = grid_w // self.spatial_merge_size
            num_groups = grid_t * llm_grid_h * llm_grid_w

            index = torch.arange(num_groups).reshape(grid_t, llm_grid_h, llm_grid_w)

            # The padding is never zero: a dimension that is already a
            # multiple of the window size gets one full extra window. Those
            # windows hold only the -100 sentinel, so their length is 0 and
            # they show up as repeated entries in cu_window_seqlens.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size

            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            # Bring the two window axes together so each window is contiguous.
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )

            # Number of real (non-padding) groups in every window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)

            # Each merged group stands for merge_unit un-merged patches.
            cu_seqlens_tmp = seqlens.cumsum(0) * merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += num_groups

        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def _merge_ordered_ids(self, ids, h, w):
        """Flatten an ``(h, w)`` id grid with each merge group kept contiguous."""
        merge = self.spatial_merge_size
        ids = ids.reshape(h // merge, merge, w // merge, merge)
        return ids.permute(0, 2, 1, 3).flatten()

    def get_position_embeddings(self, grid_thw, device):
        """Look up rotary embeddings for every patch position.

        Args:
            grid_thw: 2-D tensor whose rows are ``(t, h, w)`` patch grid
                sizes (it is sliced with ``[:, 1:]``).
            device: Device on which the position ids are created and which
                is passed to the rotary embedding module.

        Returns:
            The rows of the rotary table selected by each patch's
            ``(row, column)`` id pair, flattened from dim 1 onward.
        """
        pos_ids = []

        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
            hpos_ids = self._merge_ordered_ids(hpos_ids, h, w)

            wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
            wpos_ids = self._merge_ordered_ids(wpos_ids, h, w)

            # The same spatial ids are reused for each of the t frames.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
        return rotary_pos_emb_full[pos_ids].flatten(1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode patches into merged vision tokens.

        Args:
            pixel_values: Input handed to the patch embedding, which must
                return a 2-D ``(seq_len, channels)`` tensor.
            image_grid_thw: 2-D tensor whose rows are ``(t, h, w)`` patch
                grid sizes. Required even though the signature gives it a
                default.

        Returns:
            The output of the ``PatchMerger`` applied to the final hidden
            states, which are still in window-major order.
            NOTE(review): no inverse of ``window_index`` is applied in this
            method - confirm that callers or the merger expect that order.

        Raises:
            ValueError: If ``image_grid_thw`` is ``None``.
        """
        if image_grid_thw is None:
            raise ValueError("image_grid_thw is required to compute attention windows and position embeddings")

        optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)

        hidden_states = self.patch_embed(pixel_values)

        window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
        cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
        # Drop the repeated boundaries produced by empty (all-padding) windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)

        seq_len, _ = hidden_states.size()
        spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        # Reorder whole merge groups into window-major order.
        hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        # Apply the identical reordering to the position embeddings.
        position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
        position_embeddings = position_embeddings[window_index, :, :]
        position_embeddings = position_embeddings.reshape(seq_len, -1)
        position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
        position_embeddings = (position_embeddings.cos(), position_embeddings.sin())

        # Full-attention boundaries: one segment of h * w patches per frame.
        cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for i, block in enumerate(self.blocks):
            if i in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)

        hidden_states = self.merger(hidden_states)
        return hidden_states
|
||||||
@ -210,7 +210,7 @@ class T5Stack(torch.nn.Module):
|
|||||||
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
|
||||||
# self.dropout = nn.Dropout(config.dropout_rate)
|
# self.dropout = nn.Dropout(config.dropout_rate)
|
||||||
|
|
||||||
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
|
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
|
||||||
mask = None
|
mask = None
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||||
|
|||||||
@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO):
|
|||||||
class AnyType(ComfyTypeIO):
|
class AnyType(ComfyTypeIO):
|
||||||
Type = Any
|
Type = Any
|
||||||
|
|
||||||
|
@comfytype(io_type="MODEL_PATCH")
|
||||||
|
class MODEL_PATCH(ComfyTypeIO):
|
||||||
|
Type = Any
|
||||||
|
|
||||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||||
class MultiType:
|
class MultiType:
|
||||||
Type = Any
|
Type = Any
|
||||||
|
|||||||
@ -10,6 +10,11 @@ from typing import Type
|
|||||||
import av
|
import av
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
try:
|
||||||
|
import torchaudio
|
||||||
|
TORCH_AUDIO_AVAILABLE = True
|
||||||
|
except:
|
||||||
|
TORCH_AUDIO_AVAILABLE = False
|
||||||
from PIL import Image as PILImage
|
from PIL import Image as PILImage
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import aiohttp
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import mimetypes
|
import mimetypes
|
||||||
@ -21,7 +22,6 @@ from comfy.cmd.server import PromptServer
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import base64
|
import base64
|
||||||
@ -30,7 +30,7 @@ from io import BytesIO
|
|||||||
import av
|
import av
|
||||||
|
|
||||||
|
|
||||||
def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
|
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
|
||||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr
|
|||||||
Returns:
|
Returns:
|
||||||
A Comfy node `VIDEO` output.
|
A Comfy node `VIDEO` output.
|
||||||
"""
|
"""
|
||||||
video_io = download_url_to_bytesio(video_url, timeout)
|
video_io = await download_url_to_bytesio(video_url, timeout)
|
||||||
if video_io is None:
|
if video_io is None:
|
||||||
error_msg = f"Failed to download video from {video_url}"
|
error_msg = f"Failed to download video from {video_url}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def validate_and_cast_response(
|
async def validate_and_cast_response(
|
||||||
response, timeout: int = None, node_id: Union[str, None] = None
|
response, timeout: int = None, node_id: Union[str, None] = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Validates and casts a response to a torch.Tensor.
|
"""Validates and casts a response to a torch.Tensor.
|
||||||
@ -86,35 +86,24 @@ def validate_and_cast_response(
|
|||||||
image_tensors: list[torch.Tensor] = []
|
image_tensors: list[torch.Tensor] = []
|
||||||
|
|
||||||
# Process each image in the data array
|
# Process each image in the data array
|
||||||
for image_data in data:
|
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||||||
image_url = image_data.url
|
for img_data in data:
|
||||||
b64_data = image_data.b64_json
|
img_bytes: bytes
|
||||||
|
if img_data.b64_json:
|
||||||
|
img_bytes = base64.b64decode(img_data.b64_json)
|
||||||
|
elif img_data.url:
|
||||||
|
if node_id:
|
||||||
|
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
||||||
|
async with session.get(img_data.url) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise ValueError("Failed to download generated image")
|
||||||
|
img_bytes = await resp.read()
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
||||||
|
|
||||||
if not image_url and not b64_data:
|
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
||||||
raise ValueError("No image was generated in the response")
|
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||||
|
image_tensors.append(torch.from_numpy(arr))
|
||||||
if b64_data:
|
|
||||||
img_data = base64.b64decode(b64_data)
|
|
||||||
img = Image.open(io.BytesIO(img_data))
|
|
||||||
|
|
||||||
elif image_url:
|
|
||||||
if node_id:
|
|
||||||
PromptServer.instance.send_progress_text(
|
|
||||||
f"Result URL: {image_url}", node_id
|
|
||||||
)
|
|
||||||
img_response = requests.get(image_url, timeout=timeout)
|
|
||||||
if img_response.status_code != 200:
|
|
||||||
raise ValueError("Failed to download the image")
|
|
||||||
img = Image.open(io.BytesIO(img_response.content))
|
|
||||||
|
|
||||||
img = img.convert("RGBA")
|
|
||||||
|
|
||||||
# Convert to numpy array, normalize to float32 between 0 and 1
|
|
||||||
img_array = np.array(img).astype(np.float32) / 255.0
|
|
||||||
img_tensor = torch.from_numpy(img_array)
|
|
||||||
|
|
||||||
# Add to list of tensors
|
|
||||||
image_tensors.append(img_tensor)
|
|
||||||
|
|
||||||
return torch.stack(image_tensors, dim=0)
|
return torch.stack(image_tensors, dim=0)
|
||||||
|
|
||||||
@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str:
|
|||||||
return mime_type.split("/")[-1].lower()
|
return mime_type.split("/")[-1].lower()
|
||||||
|
|
||||||
|
|
||||||
def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
||||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
|
|||||||
Returns:
|
Returns:
|
||||||
BytesIO object containing the downloaded content.
|
BytesIO object containing the downloaded content.
|
||||||
"""
|
"""
|
||||||
response = requests.get(url, stream=True, timeout=timeout)
|
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
|
||||||
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||||
return BytesIO(response.content)
|
async with session.get(url) as resp:
|
||||||
|
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||||
|
return BytesIO(await resp.read())
|
||||||
|
|
||||||
|
|
||||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||||
@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch
|
|||||||
return torch.from_numpy(image_array).unsqueeze(0)
|
return torch.from_numpy(image_array).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||||
image_bytesio = download_url_to_bytesio(url, timeout)
|
image_bytesio = await download_url_to_bytesio(url, timeout)
|
||||||
return bytesio_to_image_tensor(image_bytesio)
|
return bytesio_to_image_tensor(image_bytesio)
|
||||||
|
|
||||||
|
|
||||||
def process_image_response(response: requests.Response) -> torch.Tensor:
|
def process_image_response(response_content: bytes | str) -> torch.Tensor:
|
||||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||||
return bytesio_to_image_tensor(BytesIO(response.content))
|
return bytesio_to_image_tensor(BytesIO(response_content))
|
||||||
|
|
||||||
|
|
||||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||||
@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str:
|
|||||||
return f"data:{mime_type};base64,{base64_string}"
|
return f"data:{mime_type};base64,{base64_string}"
|
||||||
|
|
||||||
|
|
||||||
def upload_file_to_comfyapi(
|
async def upload_file_to_comfyapi(
|
||||||
file_bytes_io: BytesIO,
|
file_bytes_io: BytesIO,
|
||||||
filename: str,
|
filename: str,
|
||||||
upload_mime_type: str,
|
upload_mime_type: Optional[str],
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
@ -354,7 +345,10 @@ def upload_file_to_comfyapi(
|
|||||||
Returns:
|
Returns:
|
||||||
The download URL for the uploaded file.
|
The download URL for the uploaded file.
|
||||||
"""
|
"""
|
||||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
if upload_mime_type is None:
|
||||||
|
request_object = UploadRequest(file_name=filename)
|
||||||
|
else:
|
||||||
|
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||||
operation = SynchronousOperation(
|
operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/customers/storage",
|
path="/customers/storage",
|
||||||
@ -366,12 +360,8 @@ def upload_file_to_comfyapi(
|
|||||||
auth_kwargs=auth_kwargs,
|
auth_kwargs=auth_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
response: UploadResponse = operation.execute()
|
response: UploadResponse = await operation.execute()
|
||||||
upload_response = ApiClient.upload_file(
|
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
|
||||||
response.upload_url, file_bytes_io, content_type=upload_mime_type
|
|
||||||
)
|
|
||||||
upload_response.raise_for_status()
|
|
||||||
|
|
||||||
return response.download_url
|
return response.download_url
|
||||||
|
|
||||||
|
|
||||||
@ -399,7 +389,7 @@ def video_to_base64_string(
|
|||||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def upload_video_to_comfyapi(
|
async def upload_video_to_comfyapi(
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
container: VideoContainer = VideoContainer.MP4,
|
container: VideoContainer = VideoContainer.MP4,
|
||||||
@ -439,9 +429,7 @@ def upload_video_to_comfyapi(
|
|||||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
return upload_file_to_comfyapi(
|
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
|
||||||
video_bytes_io, filename, upload_mime_type, auth_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||||
@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio(
|
|||||||
return audio_bytes_io
|
return audio_bytes_io
|
||||||
|
|
||||||
|
|
||||||
def upload_audio_to_comfyapi(
|
async def upload_audio_to_comfyapi(
|
||||||
audio: AudioInput,
|
audio: AudioInput,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
container_format: str = "mp4",
|
container_format: str = "mp4",
|
||||||
@ -527,7 +515,7 @@ def upload_audio_to_comfyapi(
|
|||||||
audio_data_np, sample_rate, container_format, codec_name
|
audio_data_np, sample_rate, container_format, codec_name
|
||||||
)
|
)
|
||||||
|
|
||||||
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||||
|
|
||||||
|
|
||||||
def audio_to_base64_string(
|
def audio_to_base64_string(
|
||||||
@ -544,7 +532,7 @@ def audio_to_base64_string(
|
|||||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
def upload_images_to_comfyapi(
|
async def upload_images_to_comfyapi(
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
max_images=8,
|
max_images=8,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
@ -561,55 +549,15 @@ def upload_images_to_comfyapi(
|
|||||||
mime_type: Optional MIME type for the image.
|
mime_type: Optional MIME type for the image.
|
||||||
"""
|
"""
|
||||||
# if batch, try to upload each file if max_images is greater than 0
|
# if batch, try to upload each file if max_images is greater than 0
|
||||||
idx_image = 0
|
|
||||||
download_urls: list[str] = []
|
download_urls: list[str] = []
|
||||||
is_batch = len(image.shape) > 3
|
is_batch = len(image.shape) > 3
|
||||||
batch_length = 1
|
batch_len = image.shape[0] if is_batch else 1
|
||||||
if is_batch:
|
|
||||||
batch_length = image.shape[0]
|
|
||||||
while True:
|
|
||||||
curr_image = image
|
|
||||||
if len(image.shape) > 3:
|
|
||||||
curr_image = image[idx_image]
|
|
||||||
# get BytesIO version of image
|
|
||||||
img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
|
|
||||||
# first, request upload/download urls from comfy API
|
|
||||||
if not mime_type:
|
|
||||||
request_object = UploadRequest(file_name=img_binary.name)
|
|
||||||
else:
|
|
||||||
request_object = UploadRequest(
|
|
||||||
file_name=img_binary.name, content_type=mime_type
|
|
||||||
)
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/customers/storage",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=UploadRequest,
|
|
||||||
response_model=UploadResponse,
|
|
||||||
),
|
|
||||||
request=request_object,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
response = operation.execute()
|
|
||||||
|
|
||||||
upload_response = ApiClient.upload_file(
|
for idx in range(min(batch_len, max_images)):
|
||||||
response.upload_url, img_binary, content_type=mime_type
|
tensor = image[idx] if is_batch else image
|
||||||
)
|
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||||
# verify success
|
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
|
||||||
try:
|
download_urls.append(url)
|
||||||
upload_response.raise_for_status()
|
|
||||||
except requests.exceptions.HTTPError as e:
|
|
||||||
raise ValueError(f"Could not upload one or more images: {e}") from e
|
|
||||||
# add download_url to list
|
|
||||||
download_urls.append(response.download_url)
|
|
||||||
|
|
||||||
idx_image += 1
|
|
||||||
# stop uploading additional files if done
|
|
||||||
if is_batch and max_images > 0:
|
|
||||||
if idx_image >= max_images:
|
|
||||||
break
|
|
||||||
if idx_image >= batch_length:
|
|
||||||
break
|
|
||||||
return download_urls
|
return download_urls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
16
comfy_api_nodes/apis/__init__.py
generated
16
comfy_api_nodes/apis/__init__.py
generated
@ -1315,6 +1315,7 @@ class KlingTaskStatus(str, Enum):
|
|||||||
class KlingTextToVideoModelName(str, Enum):
|
class KlingTextToVideoModelName(str, Enum):
|
||||||
kling_v1 = 'kling-v1'
|
kling_v1 = 'kling-v1'
|
||||||
kling_v1_6 = 'kling-v1-6'
|
kling_v1_6 = 'kling-v1-6'
|
||||||
|
kling_v2_1_master = 'kling-v2-1-master'
|
||||||
|
|
||||||
|
|
||||||
class KlingVideoGenAspectRatio(str, Enum):
|
class KlingVideoGenAspectRatio(str, Enum):
|
||||||
@ -1347,6 +1348,8 @@ class KlingVideoGenModelName(str, Enum):
|
|||||||
kling_v1_5 = 'kling-v1-5'
|
kling_v1_5 = 'kling-v1-5'
|
||||||
kling_v1_6 = 'kling-v1-6'
|
kling_v1_6 = 'kling-v1-6'
|
||||||
kling_v2_master = 'kling-v2-master'
|
kling_v2_master = 'kling-v2-master'
|
||||||
|
kling_v2_1 = 'kling-v2-1'
|
||||||
|
kling_v2_1_master = 'kling-v2-1-master'
|
||||||
|
|
||||||
|
|
||||||
class KlingVideoResult(BaseModel):
|
class KlingVideoResult(BaseModel):
|
||||||
@ -1620,13 +1623,14 @@ class MinimaxTaskResultResponse(BaseModel):
|
|||||||
task_id: str = Field(..., description='The task ID being queried.')
|
task_id: str = Field(..., description='The task ID being queried.')
|
||||||
|
|
||||||
|
|
||||||
class Model(str, Enum):
|
class MiniMaxModel(str, Enum):
|
||||||
T2V_01_Director = 'T2V-01-Director'
|
T2V_01_Director = 'T2V-01-Director'
|
||||||
I2V_01_Director = 'I2V-01-Director'
|
I2V_01_Director = 'I2V-01-Director'
|
||||||
S2V_01 = 'S2V-01'
|
S2V_01 = 'S2V-01'
|
||||||
I2V_01 = 'I2V-01'
|
I2V_01 = 'I2V-01'
|
||||||
I2V_01_live = 'I2V-01-live'
|
I2V_01_live = 'I2V-01-live'
|
||||||
T2V_01 = 'T2V-01'
|
T2V_01 = 'T2V-01'
|
||||||
|
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||||
|
|
||||||
|
|
||||||
class SubjectReferenceItem(BaseModel):
|
class SubjectReferenceItem(BaseModel):
|
||||||
@ -1648,7 +1652,7 @@ class MinimaxVideoGenerationRequest(BaseModel):
|
|||||||
None,
|
None,
|
||||||
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||||
)
|
)
|
||||||
model: Model = Field(
|
model: MiniMaxModel = Field(
|
||||||
...,
|
...,
|
||||||
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||||
)
|
)
|
||||||
@ -1665,6 +1669,14 @@ class MinimaxVideoGenerationRequest(BaseModel):
|
|||||||
None,
|
None,
|
||||||
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||||
)
|
)
|
||||||
|
duration: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="The length of the output video in seconds."
|
||||||
|
)
|
||||||
|
resolution: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MinimaxVideoGenerationResponse(BaseModel):
|
class MinimaxVideoGenerationResponse(BaseModel):
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import io
|
import io
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional
|
||||||
@ -28,7 +29,7 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import requests
|
import aiohttp
|
||||||
import torch
|
import torch
|
||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor):
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def handle_bfl_synchronous_operation(
|
async def handle_bfl_synchronous_operation(
|
||||||
operation: SynchronousOperation,
|
operation: SynchronousOperation,
|
||||||
timeout_bfl_calls=360,
|
timeout_bfl_calls=360,
|
||||||
node_id: Union[str, None] = None,
|
node_id: Union[str, None] = None,
|
||||||
):
|
):
|
||||||
response_api: BFLFluxProGenerateResponse = operation.execute()
|
response_api: BFLFluxProGenerateResponse = await operation.execute()
|
||||||
return _poll_until_generated(
|
return await _poll_until_generated(
|
||||||
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _poll_until_generated(
|
async def _poll_until_generated(
|
||||||
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
||||||
):
|
):
|
||||||
# used bfl-comfy-nodes to verify code implementation:
|
# used bfl-comfy-nodes to verify code implementation:
|
||||||
@ -66,55 +67,56 @@ def _poll_until_generated(
|
|||||||
retry_404_seconds = 2
|
retry_404_seconds = 2
|
||||||
retry_202_seconds = 2
|
retry_202_seconds = 2
|
||||||
retry_pending_seconds = 1
|
retry_pending_seconds = 1
|
||||||
request = requests.Request(method=HttpMethod.GET, url=polling_url)
|
|
||||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
|
||||||
while True:
|
|
||||||
if node_id:
|
|
||||||
time_elapsed = time.time() - start_time
|
|
||||||
PromptServer.instance.send_progress_text(
|
|
||||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
|
||||||
)
|
|
||||||
|
|
||||||
response = requests.Session().send(request.prepare())
|
async with aiohttp.ClientSession() as session:
|
||||||
if response.status_code == 200:
|
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||||
result = response.json()
|
while True:
|
||||||
if result["status"] == BFLStatus.ready:
|
if node_id:
|
||||||
img_url = result["result"]["sample"]
|
time_elapsed = time.time() - start_time
|
||||||
if node_id:
|
PromptServer.instance.send_progress_text(
|
||||||
PromptServer.instance.send_progress_text(
|
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||||
f"Result URL: {img_url}", node_id
|
|
||||||
)
|
|
||||||
img_response = requests.get(img_url)
|
|
||||||
return process_image_response(img_response)
|
|
||||||
elif result["status"] in [
|
|
||||||
BFLStatus.request_moderated,
|
|
||||||
BFLStatus.content_moderated,
|
|
||||||
]:
|
|
||||||
status = result["status"]
|
|
||||||
raise Exception(
|
|
||||||
f"BFL API did not return an image due to: {status}."
|
|
||||||
)
|
)
|
||||||
elif result["status"] == BFLStatus.error:
|
|
||||||
raise Exception(f"BFL API encountered an error: {result}.")
|
async with session.get(polling_url) as response:
|
||||||
elif result["status"] == BFLStatus.pending:
|
if response.status == 200:
|
||||||
time.sleep(retry_pending_seconds)
|
result = await response.json()
|
||||||
continue
|
if result["status"] == BFLStatus.ready:
|
||||||
elif response.status_code == 404:
|
img_url = result["result"]["sample"]
|
||||||
if retries_404 < max_retries_404:
|
if node_id:
|
||||||
retries_404 += 1
|
PromptServer.instance.send_progress_text(
|
||||||
time.sleep(retry_404_seconds)
|
f"Result URL: {img_url}", node_id
|
||||||
continue
|
)
|
||||||
raise Exception(
|
async with session.get(img_url) as img_resp:
|
||||||
f"BFL API could not find task after {max_retries_404} tries."
|
return process_image_response(await img_resp.content.read())
|
||||||
)
|
elif result["status"] in [
|
||||||
elif response.status_code == 202:
|
BFLStatus.request_moderated,
|
||||||
time.sleep(retry_202_seconds)
|
BFLStatus.content_moderated,
|
||||||
elif time.time() - start_time > timeout:
|
]:
|
||||||
raise Exception(
|
status = result["status"]
|
||||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
raise Exception(
|
||||||
)
|
f"BFL API did not return an image due to: {status}."
|
||||||
else:
|
)
|
||||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
elif result["status"] == BFLStatus.error:
|
||||||
|
raise Exception(f"BFL API encountered an error: {result}.")
|
||||||
|
elif result["status"] == BFLStatus.pending:
|
||||||
|
await asyncio.sleep(retry_pending_seconds)
|
||||||
|
continue
|
||||||
|
elif response.status == 404:
|
||||||
|
if retries_404 < max_retries_404:
|
||||||
|
retries_404 += 1
|
||||||
|
await asyncio.sleep(retry_404_seconds)
|
||||||
|
continue
|
||||||
|
raise Exception(
|
||||||
|
f"BFL API could not find task after {max_retries_404} tries."
|
||||||
|
)
|
||||||
|
elif response.status == 202:
|
||||||
|
await asyncio.sleep(retry_202_seconds)
|
||||||
|
elif time.time() - start_time > timeout:
|
||||||
|
raise Exception(
|
||||||
|
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||||
|
|
||||||
def convert_image_to_base64(image: torch.Tensor):
|
def convert_image_to_base64(image: torch.Tensor):
|
||||||
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
||||||
@ -222,7 +224,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/BFL"
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
@ -266,7 +268,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||||
return (output_image,)
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
@ -354,7 +356,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
|||||||
|
|
||||||
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
|
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
@ -397,7 +399,7 @@ class FluxKontextProImageNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||||
return (output_image,)
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
@ -489,7 +491,7 @@ class FluxProImageNode(ComfyNodeABC):
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/BFL"
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
prompt_upsampling,
|
prompt_upsampling,
|
||||||
@ -524,7 +526,7 @@ class FluxProImageNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||||
return (output_image,)
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
@ -632,7 +634,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/BFL"
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -670,7 +672,7 @@ class FluxProExpandNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||||
return (output_image,)
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
@ -744,7 +746,7 @@ class FluxProFillNode(ComfyNodeABC):
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/BFL"
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
mask: torch.Tensor,
|
mask: torch.Tensor,
|
||||||
@ -780,7 +782,7 @@ class FluxProFillNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||||
return (output_image,)
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
@ -879,7 +881,7 @@ class FluxProCannyNode(ComfyNodeABC):
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/BFL"
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
control_image: torch.Tensor,
|
control_image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -929,7 +931,7 @@ class FluxProCannyNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||||
return (output_image,)
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
@ -1008,7 +1010,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
CATEGORY = "api node/image/BFL"
|
CATEGORY = "api node/image/BFL"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
control_image: torch.Tensor,
|
control_image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -1045,7 +1047,7 @@ class FluxProDepthNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id)
|
||||||
return (output_image,)
|
return (output_image,)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,10 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Literal
|
from typing import Optional, Literal
|
||||||
|
|
||||||
@ -46,6 +49,8 @@ class GeminiModel(str, Enum):
|
|||||||
|
|
||||||
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
|
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
|
||||||
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
||||||
|
gemini_2_5_pro = "gemini-2.5-pro"
|
||||||
|
gemini_2_5_flash = "gemini-2.5-flash"
|
||||||
|
|
||||||
|
|
||||||
def get_gemini_endpoint(
|
def get_gemini_endpoint(
|
||||||
@ -97,7 +102,7 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
{
|
{
|
||||||
"tooltip": "The Gemini model to use for generating responses.",
|
"tooltip": "The Gemini model to use for generating responses.",
|
||||||
"options": [model.value for model in GeminiModel],
|
"options": [model.value for model in GeminiModel],
|
||||||
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
|
"default": GeminiModel.gemini_2_5_pro.value,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
"seed": (
|
"seed": (
|
||||||
@ -303,7 +308,7 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
"""
|
"""
|
||||||
return GeminiPart(text=text)
|
return GeminiPart(text=text)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: GeminiModel,
|
model: GeminiModel,
|
||||||
@ -332,7 +337,7 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
parts.extend(files)
|
parts.extend(files)
|
||||||
|
|
||||||
# Create response
|
# Create response
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=get_gemini_endpoint(model),
|
endpoint=get_gemini_endpoint(model),
|
||||||
request=GeminiGenerateContentRequest(
|
request=GeminiGenerateContentRequest(
|
||||||
contents=[
|
contents=[
|
||||||
@ -348,7 +353,27 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
# Get result output
|
# Get result output
|
||||||
output_text = self.get_text_from_response(response)
|
output_text = self.get_text_from_response(response)
|
||||||
if unique_id and output_text:
|
if unique_id and output_text:
|
||||||
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
|
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||||
|
render_spec = {
|
||||||
|
"node_id": unique_id,
|
||||||
|
"component": "ChatHistoryWidget",
|
||||||
|
"props": {
|
||||||
|
"history": json.dumps(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
"response": output_text,
|
||||||
|
"response_id": str(uuid.uuid4()),
|
||||||
|
"timestamp": time.time(),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
PromptServer.instance.send_sync(
|
||||||
|
"display_component",
|
||||||
|
render_spec,
|
||||||
|
)
|
||||||
|
|
||||||
return (output_text or "Empty response from Gemini model...",)
|
return (output_text or "Empty response from Gemini model...",)
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
from io import BytesIO
|
||||||
from inspect import cleandoc
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import io
|
|
||||||
import torch
|
import torch
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
IdeogramGenerateRequest,
|
IdeogramGenerateRequest,
|
||||||
@ -212,7 +212,7 @@ V3_RESOLUTIONS= [
|
|||||||
"1536x640"
|
"1536x640"
|
||||||
]
|
]
|
||||||
|
|
||||||
def download_and_process_images(image_urls):
|
async def download_and_process_images(image_urls):
|
||||||
"""Helper function to download and process multiple images from URLs"""
|
"""Helper function to download and process multiple images from URLs"""
|
||||||
|
|
||||||
# Initialize list to store image tensors
|
# Initialize list to store image tensors
|
||||||
@ -220,7 +220,7 @@ def download_and_process_images(image_urls):
|
|||||||
|
|
||||||
for image_url in image_urls:
|
for image_url in image_urls:
|
||||||
# Using functions from apinode_utils.py to handle downloading and processing
|
# Using functions from apinode_utils.py to handle downloading and processing
|
||||||
image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
|
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||||
image_tensors.append(img_tensor)
|
image_tensors.append(img_tensor)
|
||||||
|
|
||||||
@ -246,90 +246,81 @@ def display_image_urls_on_node(image_urls, node_id):
|
|||||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||||
|
|
||||||
|
|
||||||
class IdeogramV1(ComfyNodeABC):
|
class IdeogramV1(comfy_io.ComfyNode):
|
||||||
"""
|
|
||||||
Generates images using the Ideogram V1 model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def define_schema(cls):
|
||||||
return {
|
return comfy_io.Schema(
|
||||||
"required": {
|
node_id="IdeogramV1",
|
||||||
"prompt": (
|
display_name="Ideogram V1",
|
||||||
IO.STRING,
|
category="api node/image/Ideogram",
|
||||||
{
|
description="Generates images using the Ideogram V1 model.",
|
||||||
"multiline": True,
|
inputs=[
|
||||||
"default": "",
|
comfy_io.String.Input(
|
||||||
"tooltip": "Prompt for the image generation",
|
"prompt",
|
||||||
},
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Prompt for the image generation",
|
||||||
),
|
),
|
||||||
"turbo": (
|
comfy_io.Boolean.Input(
|
||||||
IO.BOOLEAN,
|
"turbo",
|
||||||
{
|
default=False,
|
||||||
"default": False,
|
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
|
||||||
}
|
|
||||||
),
|
),
|
||||||
},
|
comfy_io.Combo.Input(
|
||||||
"optional": {
|
"aspect_ratio",
|
||||||
"aspect_ratio": (
|
options=list(V1_V2_RATIO_MAP.keys()),
|
||||||
IO.COMBO,
|
default="1:1",
|
||||||
{
|
tooltip="The aspect ratio for image generation.",
|
||||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
optional=True,
|
||||||
"default": "1:1",
|
|
||||||
"tooltip": "The aspect ratio for image generation.",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"magic_prompt_option": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"magic_prompt_option",
|
||||||
{
|
options=["AUTO", "ON", "OFF"],
|
||||||
"options": ["AUTO", "ON", "OFF"],
|
default="AUTO",
|
||||||
"default": "AUTO",
|
tooltip="Determine if MagicPrompt should be used in generation",
|
||||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"seed": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"seed",
|
||||||
{
|
default=0,
|
||||||
"default": 0,
|
min=0,
|
||||||
"min": 0,
|
max=2147483647,
|
||||||
"max": 2147483647,
|
step=1,
|
||||||
"step": 1,
|
control_after_generate=True,
|
||||||
"control_after_generate": True,
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
"display": "number",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"negative_prompt": (
|
comfy_io.String.Input(
|
||||||
IO.STRING,
|
"negative_prompt",
|
||||||
{
|
multiline=True,
|
||||||
"multiline": True,
|
default="",
|
||||||
"default": "",
|
tooltip="Description of what to exclude from the image",
|
||||||
"tooltip": "Description of what to exclude from the image",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"num_images": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"num_images",
|
||||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=8,
|
||||||
|
step=1,
|
||||||
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
|
optional=True,
|
||||||
),
|
),
|
||||||
},
|
],
|
||||||
"hidden": {
|
outputs=[
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
comfy_io.Image.Output(),
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
],
|
||||||
"unique_id": "UNIQUE_ID",
|
hidden=[
|
||||||
},
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
}
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
@classmethod
|
||||||
FUNCTION = "api_call"
|
async def execute(
|
||||||
CATEGORY = "api node/image/Ideogram"
|
cls,
|
||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def api_call(
|
|
||||||
self,
|
|
||||||
prompt,
|
prompt,
|
||||||
turbo=False,
|
turbo=False,
|
||||||
aspect_ratio="1:1",
|
aspect_ratio="1:1",
|
||||||
@ -337,13 +328,15 @@ class IdeogramV1(ComfyNodeABC):
|
|||||||
seed=0,
|
seed=0,
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
num_images=1,
|
num_images=1,
|
||||||
unique_id=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
# Determine the model based on turbo setting
|
# Determine the model based on turbo setting
|
||||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||||
model = "V_1_TURBO" if turbo else "V_1"
|
model = "V_1_TURBO" if turbo else "V_1"
|
||||||
|
|
||||||
|
auth = {
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
|
}
|
||||||
operation = SynchronousOperation(
|
operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/ideogram/generate",
|
path="/proxy/ideogram/generate",
|
||||||
@ -364,10 +357,10 @@ class IdeogramV1(ComfyNodeABC):
|
|||||||
negative_prompt=negative_prompt if negative_prompt else None,
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = operation.execute()
|
response = await operation.execute()
|
||||||
|
|
||||||
if not response.data or len(response.data) == 0:
|
if not response.data or len(response.data) == 0:
|
||||||
raise Exception("No images were generated in the response")
|
raise Exception("No images were generated in the response")
|
||||||
@ -377,93 +370,85 @@ class IdeogramV1(ComfyNodeABC):
|
|||||||
if not image_urls:
|
if not image_urls:
|
||||||
raise Exception("No image URLs were generated in the response")
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
|
||||||
display_image_urls_on_node(image_urls, unique_id)
|
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||||
return (download_and_process_images(image_urls),)
|
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||||
|
|
||||||
|
|
||||||
class IdeogramV2(ComfyNodeABC):
|
class IdeogramV2(comfy_io.ComfyNode):
|
||||||
"""
|
|
||||||
Generates images using the Ideogram V2 model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def define_schema(cls):
|
||||||
return {
|
return comfy_io.Schema(
|
||||||
"required": {
|
node_id="IdeogramV2",
|
||||||
"prompt": (
|
display_name="Ideogram V2",
|
||||||
IO.STRING,
|
category="api node/image/Ideogram",
|
||||||
{
|
description="Generates images using the Ideogram V2 model.",
|
||||||
"multiline": True,
|
inputs=[
|
||||||
"default": "",
|
comfy_io.String.Input(
|
||||||
"tooltip": "Prompt for the image generation",
|
"prompt",
|
||||||
},
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Prompt for the image generation",
|
||||||
),
|
),
|
||||||
"turbo": (
|
comfy_io.Boolean.Input(
|
||||||
IO.BOOLEAN,
|
"turbo",
|
||||||
{
|
default=False,
|
||||||
"default": False,
|
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||||
"tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
|
|
||||||
}
|
|
||||||
),
|
),
|
||||||
},
|
comfy_io.Combo.Input(
|
||||||
"optional": {
|
"aspect_ratio",
|
||||||
"aspect_ratio": (
|
options=list(V1_V2_RATIO_MAP.keys()),
|
||||||
IO.COMBO,
|
default="1:1",
|
||||||
{
|
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||||
"options": list(V1_V2_RATIO_MAP.keys()),
|
optional=True,
|
||||||
"default": "1:1",
|
|
||||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"resolution": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"resolution",
|
||||||
{
|
options=list(V1_V1_RES_MAP.keys()),
|
||||||
"options": list(V1_V1_RES_MAP.keys()),
|
default="Auto",
|
||||||
"default": "Auto",
|
tooltip="The resolution for image generation. "
|
||||||
"tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
|
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||||
},
|
optional=True,
|
||||||
),
|
),
|
||||||
"magic_prompt_option": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"magic_prompt_option",
|
||||||
{
|
options=["AUTO", "ON", "OFF"],
|
||||||
"options": ["AUTO", "ON", "OFF"],
|
default="AUTO",
|
||||||
"default": "AUTO",
|
tooltip="Determine if MagicPrompt should be used in generation",
|
||||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"seed": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"seed",
|
||||||
{
|
default=0,
|
||||||
"default": 0,
|
min=0,
|
||||||
"min": 0,
|
max=2147483647,
|
||||||
"max": 2147483647,
|
step=1,
|
||||||
"step": 1,
|
control_after_generate=True,
|
||||||
"control_after_generate": True,
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
"display": "number",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"style_type": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"style_type",
|
||||||
{
|
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||||
"options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
default="NONE",
|
||||||
"default": "NONE",
|
tooltip="Style type for generation (V2 only)",
|
||||||
"tooltip": "Style type for generation (V2 only)",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"negative_prompt": (
|
comfy_io.String.Input(
|
||||||
IO.STRING,
|
"negative_prompt",
|
||||||
{
|
multiline=True,
|
||||||
"multiline": True,
|
default="",
|
||||||
"default": "",
|
tooltip="Description of what to exclude from the image",
|
||||||
"tooltip": "Description of what to exclude from the image",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"num_images": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"num_images",
|
||||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=8,
|
||||||
|
step=1,
|
||||||
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
|
optional=True,
|
||||||
),
|
),
|
||||||
#"color_palette": (
|
#"color_palette": (
|
||||||
# IO.STRING,
|
# IO.STRING,
|
||||||
@ -473,22 +458,20 @@ class IdeogramV2(ComfyNodeABC):
|
|||||||
# "tooltip": "Color palette preset name or hex colors with weights",
|
# "tooltip": "Color palette preset name or hex colors with weights",
|
||||||
# },
|
# },
|
||||||
#),
|
#),
|
||||||
},
|
],
|
||||||
"hidden": {
|
outputs=[
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
comfy_io.Image.Output(),
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
],
|
||||||
"unique_id": "UNIQUE_ID",
|
hidden=[
|
||||||
},
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
}
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
@classmethod
|
||||||
FUNCTION = "api_call"
|
async def execute(
|
||||||
CATEGORY = "api node/image/Ideogram"
|
cls,
|
||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def api_call(
|
|
||||||
self,
|
|
||||||
prompt,
|
prompt,
|
||||||
turbo=False,
|
turbo=False,
|
||||||
aspect_ratio="1:1",
|
aspect_ratio="1:1",
|
||||||
@ -499,8 +482,6 @@ class IdeogramV2(ComfyNodeABC):
|
|||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
num_images=1,
|
num_images=1,
|
||||||
color_palette="",
|
color_palette="",
|
||||||
unique_id=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||||
resolution = V1_V1_RES_MAP.get(resolution, None)
|
resolution = V1_V1_RES_MAP.get(resolution, None)
|
||||||
@ -517,6 +498,10 @@ class IdeogramV2(ComfyNodeABC):
|
|||||||
else:
|
else:
|
||||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||||
|
|
||||||
|
auth = {
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
|
}
|
||||||
operation = SynchronousOperation(
|
operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/ideogram/generate",
|
path="/proxy/ideogram/generate",
|
||||||
@ -540,10 +525,10 @@ class IdeogramV2(ComfyNodeABC):
|
|||||||
color_palette=color_palette if color_palette else None,
|
color_palette=color_palette if color_palette else None,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = operation.execute()
|
response = await operation.execute()
|
||||||
|
|
||||||
if not response.data or len(response.data) == 0:
|
if not response.data or len(response.data) == 0:
|
||||||
raise Exception("No images were generated in the response")
|
raise Exception("No images were generated in the response")
|
||||||
@ -553,108 +538,99 @@ class IdeogramV2(ComfyNodeABC):
|
|||||||
if not image_urls:
|
if not image_urls:
|
||||||
raise Exception("No image URLs were generated in the response")
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
|
||||||
display_image_urls_on_node(image_urls, unique_id)
|
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||||
return (download_and_process_images(image_urls),)
|
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||||
|
|
||||||
class IdeogramV3(ComfyNodeABC):
|
|
||||||
"""
|
|
||||||
Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
class IdeogramV3(comfy_io.ComfyNode):
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def define_schema(cls):
|
||||||
return {
|
return comfy_io.Schema(
|
||||||
"required": {
|
node_id="IdeogramV3",
|
||||||
"prompt": (
|
display_name="Ideogram V3",
|
||||||
IO.STRING,
|
category="api node/image/Ideogram",
|
||||||
{
|
description="Generates images using the Ideogram V3 model. "
|
||||||
"multiline": True,
|
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||||
"default": "",
|
inputs=[
|
||||||
"tooltip": "Prompt for the image generation or editing",
|
comfy_io.String.Input(
|
||||||
},
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Prompt for the image generation or editing",
|
||||||
),
|
),
|
||||||
},
|
comfy_io.Image.Input(
|
||||||
"optional": {
|
"image",
|
||||||
"image": (
|
tooltip="Optional reference image for image editing.",
|
||||||
IO.IMAGE,
|
optional=True,
|
||||||
{
|
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional reference image for image editing.",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"mask": (
|
comfy_io.Mask.Input(
|
||||||
IO.MASK,
|
"mask",
|
||||||
{
|
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||||
"default": None,
|
optional=True,
|
||||||
"tooltip": "Optional mask for inpainting (white areas will be replaced)",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"aspect_ratio": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"aspect_ratio",
|
||||||
{
|
options=list(V3_RATIO_MAP.keys()),
|
||||||
"options": list(V3_RATIO_MAP.keys()),
|
default="1:1",
|
||||||
"default": "1:1",
|
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||||
"tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"resolution": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"resolution",
|
||||||
{
|
options=V3_RESOLUTIONS,
|
||||||
"options": V3_RESOLUTIONS,
|
default="Auto",
|
||||||
"default": "Auto",
|
tooltip="The resolution for image generation. "
|
||||||
"tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
|
"If not set to Auto, this overrides the aspect_ratio setting.",
|
||||||
},
|
optional=True,
|
||||||
),
|
),
|
||||||
"magic_prompt_option": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"magic_prompt_option",
|
||||||
{
|
options=["AUTO", "ON", "OFF"],
|
||||||
"options": ["AUTO", "ON", "OFF"],
|
default="AUTO",
|
||||||
"default": "AUTO",
|
tooltip="Determine if MagicPrompt should be used in generation",
|
||||||
"tooltip": "Determine if MagicPrompt should be used in generation",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"seed": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"seed",
|
||||||
{
|
default=0,
|
||||||
"default": 0,
|
min=0,
|
||||||
"min": 0,
|
max=2147483647,
|
||||||
"max": 2147483647,
|
step=1,
|
||||||
"step": 1,
|
control_after_generate=True,
|
||||||
"control_after_generate": True,
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
"display": "number",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"num_images": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"num_images",
|
||||||
{"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=8,
|
||||||
|
step=1,
|
||||||
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
|
optional=True,
|
||||||
),
|
),
|
||||||
"rendering_speed": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"rendering_speed",
|
||||||
{
|
options=["BALANCED", "TURBO", "QUALITY"],
|
||||||
"options": ["BALANCED", "TURBO", "QUALITY"],
|
default="BALANCED",
|
||||||
"default": "BALANCED",
|
tooltip="Controls the trade-off between generation speed and quality",
|
||||||
"tooltip": "Controls the trade-off between generation speed and quality",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
},
|
],
|
||||||
"hidden": {
|
outputs=[
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
comfy_io.Image.Output(),
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
],
|
||||||
"unique_id": "UNIQUE_ID",
|
hidden=[
|
||||||
},
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
}
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.IMAGE,)
|
@classmethod
|
||||||
FUNCTION = "api_call"
|
async def execute(
|
||||||
CATEGORY = "api node/image/Ideogram"
|
cls,
|
||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def api_call(
|
|
||||||
self,
|
|
||||||
prompt,
|
prompt,
|
||||||
image=None,
|
image=None,
|
||||||
mask=None,
|
mask=None,
|
||||||
@ -664,9 +640,11 @@ class IdeogramV3(ComfyNodeABC):
|
|||||||
seed=0,
|
seed=0,
|
||||||
num_images=1,
|
num_images=1,
|
||||||
rendering_speed="BALANCED",
|
rendering_speed="BALANCED",
|
||||||
unique_id=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
|
auth = {
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
|
}
|
||||||
# Check if both image and mask are provided for editing mode
|
# Check if both image and mask are provided for editing mode
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
# Edit mode
|
# Edit mode
|
||||||
@ -686,7 +664,7 @@ class IdeogramV3(ComfyNodeABC):
|
|||||||
# Process image
|
# Process image
|
||||||
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||||
img = Image.fromarray(img_np)
|
img = Image.fromarray(img_np)
|
||||||
img_byte_arr = io.BytesIO()
|
img_byte_arr = BytesIO()
|
||||||
img.save(img_byte_arr, format="PNG")
|
img.save(img_byte_arr, format="PNG")
|
||||||
img_byte_arr.seek(0)
|
img_byte_arr.seek(0)
|
||||||
img_binary = img_byte_arr
|
img_binary = img_byte_arr
|
||||||
@ -695,7 +673,7 @@ class IdeogramV3(ComfyNodeABC):
|
|||||||
# Process mask - white areas will be replaced
|
# Process mask - white areas will be replaced
|
||||||
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
|
||||||
mask_img = Image.fromarray(mask_np)
|
mask_img = Image.fromarray(mask_np)
|
||||||
mask_byte_arr = io.BytesIO()
|
mask_byte_arr = BytesIO()
|
||||||
mask_img.save(mask_byte_arr, format="PNG")
|
mask_img.save(mask_byte_arr, format="PNG")
|
||||||
mask_byte_arr.seek(0)
|
mask_byte_arr.seek(0)
|
||||||
mask_binary = mask_byte_arr
|
mask_binary = mask_byte_arr
|
||||||
@ -729,7 +707,7 @@ class IdeogramV3(ComfyNodeABC):
|
|||||||
"mask": mask_binary,
|
"mask": mask_binary,
|
||||||
},
|
},
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif image is not None or mask is not None:
|
elif image is not None or mask is not None:
|
||||||
@ -770,11 +748,11 @@ class IdeogramV3(ComfyNodeABC):
|
|||||||
response_model=IdeogramGenerateResponse,
|
response_model=IdeogramGenerateResponse,
|
||||||
),
|
),
|
||||||
request=gen_request,
|
request=gen_request,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the operation and process response
|
# Execute the operation and process response
|
||||||
response = operation.execute()
|
response = await operation.execute()
|
||||||
|
|
||||||
if not response.data or len(response.data) == 0:
|
if not response.data or len(response.data) == 0:
|
||||||
raise Exception("No images were generated in the response")
|
raise Exception("No images were generated in the response")
|
||||||
@ -784,18 +762,18 @@ class IdeogramV3(ComfyNodeABC):
|
|||||||
if not image_urls:
|
if not image_urls:
|
||||||
raise Exception("No image URLs were generated in the response")
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
|
||||||
display_image_urls_on_node(image_urls, unique_id)
|
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||||
return (download_and_process_images(image_urls),)
|
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class IdeogramExtension(ComfyExtension):
|
||||||
"IdeogramV1": IdeogramV1,
|
@override
|
||||||
"IdeogramV2": IdeogramV2,
|
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||||
"IdeogramV3": IdeogramV3,
|
return [
|
||||||
}
|
IdeogramV1,
|
||||||
|
IdeogramV2,
|
||||||
|
IdeogramV3,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
async def comfy_entrypoint() -> IdeogramExtension:
|
||||||
"IdeogramV1": "Ideogram V1",
|
return IdeogramExtension()
|
||||||
"IdeogramV2": "Ideogram V2",
|
|
||||||
"IdeogramV3": "Ideogram V3",
|
|
||||||
}
|
|
||||||
|
|||||||
@ -109,7 +109,7 @@ class KlingApiError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def poll_until_finished(
|
async def poll_until_finished(
|
||||||
auth_kwargs: dict[str, str],
|
auth_kwargs: dict[str, str],
|
||||||
api_endpoint: ApiEndpoint[Any, R],
|
api_endpoint: ApiEndpoint[Any, R],
|
||||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||||
@ -117,7 +117,7 @@ def poll_until_finished(
|
|||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
) -> R:
|
) -> R:
|
||||||
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
|
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
|
||||||
return PollingOperation(
|
return await PollingOperation(
|
||||||
poll_endpoint=api_endpoint,
|
poll_endpoint=api_endpoint,
|
||||||
completed_statuses=[
|
completed_statuses=[
|
||||||
KlingTaskStatus.succeed.value,
|
KlingTaskStatus.succeed.value,
|
||||||
@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def video_result_to_node_output(
|
async def video_result_to_node_output(
|
||||||
video: KlingVideoResult,
|
video: KlingVideoResult,
|
||||||
) -> tuple[VideoFromFile, str, str]:
|
) -> tuple[VideoFromFile, str, str]:
|
||||||
"""Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
|
"""Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
|
||||||
return (
|
return (
|
||||||
download_url_to_video_output(video.url),
|
await download_url_to_video_output(str(video.url)),
|
||||||
str(video.id),
|
str(video.id),
|
||||||
str(video.duration),
|
str(video.duration),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def image_result_to_node_output(
|
async def image_result_to_node_output(
|
||||||
images: list[KlingImageResult],
|
images: list[KlingImageResult],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -297,9 +297,9 @@ def image_result_to_node_output(
|
|||||||
If multiple images are returned, they will be stacked along the batch dimension.
|
If multiple images are returned, they will be stacked along the batch dimension.
|
||||||
"""
|
"""
|
||||||
if len(images) == 1:
|
if len(images) == 1:
|
||||||
return download_url_to_image_tensor(images[0].url)
|
return await download_url_to_image_tensor(str(images[0].url))
|
||||||
else:
|
else:
|
||||||
return torch.cat([download_url_to_image_tensor(image.url) for image in images])
|
return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images])
|
||||||
|
|
||||||
|
|
||||||
class KlingNodeBase(ComfyNodeABC):
|
class KlingNodeBase(ComfyNodeABC):
|
||||||
@ -421,6 +421,8 @@ class KlingTextToVideoNode(KlingNodeBase):
|
|||||||
"pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
|
"pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
|
||||||
"standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
|
"standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
|
||||||
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
|
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
|
||||||
|
"pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"),
|
||||||
|
"pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"),
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -467,10 +469,10 @@ class KlingTextToVideoNode(KlingNodeBase):
|
|||||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||||
DESCRIPTION = "Kling Text to Video Node"
|
DESCRIPTION = "Kling Text to Video Node"
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> KlingText2VideoResponse:
|
) -> KlingText2VideoResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
||||||
@ -483,7 +485,7 @@ class KlingTextToVideoNode(KlingNodeBase):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
@ -519,17 +521,17 @@ class KlingTextToVideoNode(KlingNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
|
|
||||||
task_id = task_creation_response.data.task_id
|
task_id = task_creation_response.data.task_id
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
validate_video_result_response(final_response)
|
validate_video_result_response(final_response)
|
||||||
|
|
||||||
video = get_video_from_response(final_response)
|
video = get_video_from_response(final_response)
|
||||||
return video_result_to_node_output(video)
|
return await video_result_to_node_output(video)
|
||||||
|
|
||||||
|
|
||||||
class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
||||||
@ -581,7 +583,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
|||||||
|
|
||||||
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
|
DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
@ -591,7 +593,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
|
|||||||
unique_id: Optional[str] = None,
|
unique_id: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return super().api_call(
|
return await super().api_call(
|
||||||
model_name=KlingVideoGenModelName.kling_v1,
|
model_name=KlingVideoGenModelName.kling_v1,
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
mode=KlingVideoGenMode.std,
|
mode=KlingVideoGenMode.std,
|
||||||
@ -670,10 +672,10 @@ class KlingImage2VideoNode(KlingNodeBase):
|
|||||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||||
DESCRIPTION = "Kling Image to Video Node"
|
DESCRIPTION = "Kling Image to Video Node"
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> KlingImage2VideoResponse:
|
) -> KlingImage2VideoResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
|
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
|
||||||
@ -686,7 +688,7 @@ class KlingImage2VideoNode(KlingNodeBase):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
start_frame: torch.Tensor,
|
start_frame: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -733,17 +735,17 @@ class KlingImage2VideoNode(KlingNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.data.task_id
|
task_id = task_creation_response.data.task_id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
validate_video_result_response(final_response)
|
validate_video_result_response(final_response)
|
||||||
|
|
||||||
video = get_video_from_response(final_response)
|
video = get_video_from_response(final_response)
|
||||||
return video_result_to_node_output(video)
|
return await video_result_to_node_output(video)
|
||||||
|
|
||||||
|
|
||||||
class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
||||||
@ -798,7 +800,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
|||||||
|
|
||||||
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
|
DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
start_frame: torch.Tensor,
|
start_frame: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -809,7 +811,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
|
|||||||
unique_id: Optional[str] = None,
|
unique_id: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return super().api_call(
|
return await super().api_call(
|
||||||
model_name=KlingVideoGenModelName.kling_v1_5,
|
model_name=KlingVideoGenModelName.kling_v1_5,
|
||||||
start_frame=start_frame,
|
start_frame=start_frame,
|
||||||
cfg_scale=cfg_scale,
|
cfg_scale=cfg_scale,
|
||||||
@ -897,7 +899,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
|||||||
|
|
||||||
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
|
DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
start_frame: torch.Tensor,
|
start_frame: torch.Tensor,
|
||||||
end_frame: torch.Tensor,
|
end_frame: torch.Tensor,
|
||||||
@ -912,7 +914,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
|
|||||||
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
|
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
|
||||||
mode
|
mode
|
||||||
]
|
]
|
||||||
return super().api_call(
|
return await super().api_call(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -964,10 +966,10 @@ class KlingVideoExtendNode(KlingNodeBase):
|
|||||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||||
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
|
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> KlingVideoExtendResponse:
|
) -> KlingVideoExtendResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
|
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
|
||||||
@ -980,7 +982,7 @@ class KlingVideoExtendNode(KlingNodeBase):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
@ -1006,17 +1008,17 @@ class KlingVideoExtendNode(KlingNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.data.task_id
|
task_id = task_creation_response.data.task_id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
validate_video_result_response(final_response)
|
validate_video_result_response(final_response)
|
||||||
|
|
||||||
video = get_video_from_response(final_response)
|
video = get_video_from_response(final_response)
|
||||||
return video_result_to_node_output(video)
|
return await video_result_to_node_output(video)
|
||||||
|
|
||||||
|
|
||||||
class KlingVideoEffectsBase(KlingNodeBase):
|
class KlingVideoEffectsBase(KlingNodeBase):
|
||||||
@ -1025,10 +1027,10 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
|||||||
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
|
||||||
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
RETURN_NAMES = ("VIDEO", "video_id", "duration")
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> KlingVideoEffectsResponse:
|
) -> KlingVideoEffectsResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
|
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
|
||||||
@ -1041,7 +1043,7 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
dual_character: bool,
|
dual_character: bool,
|
||||||
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
|
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
|
||||||
@ -1084,17 +1086,17 @@ class KlingVideoEffectsBase(KlingNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.data.task_id
|
task_id = task_creation_response.data.task_id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
validate_video_result_response(final_response)
|
validate_video_result_response(final_response)
|
||||||
|
|
||||||
video = get_video_from_response(final_response)
|
video = get_video_from_response(final_response)
|
||||||
return video_result_to_node_output(video)
|
return await video_result_to_node_output(video)
|
||||||
|
|
||||||
|
|
||||||
class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
||||||
@ -1142,7 +1144,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
|||||||
RETURN_TYPES = ("VIDEO", "STRING")
|
RETURN_TYPES = ("VIDEO", "STRING")
|
||||||
RETURN_NAMES = ("VIDEO", "duration")
|
RETURN_NAMES = ("VIDEO", "duration")
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image_left: torch.Tensor,
|
image_left: torch.Tensor,
|
||||||
image_right: torch.Tensor,
|
image_right: torch.Tensor,
|
||||||
@ -1153,7 +1155,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
|
|||||||
unique_id: Optional[str] = None,
|
unique_id: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
video, _, duration = super().api_call(
|
video, _, duration = await super().api_call(
|
||||||
dual_character=True,
|
dual_character=True,
|
||||||
effect_scene=effect_scene,
|
effect_scene=effect_scene,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -1208,7 +1210,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
|
DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
effect_scene: KlingSingleImageEffectsScene,
|
effect_scene: KlingSingleImageEffectsScene,
|
||||||
@ -1217,7 +1219,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
|
|||||||
unique_id: Optional[str] = None,
|
unique_id: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return super().api_call(
|
return await super().api_call(
|
||||||
dual_character=False,
|
dual_character=False,
|
||||||
effect_scene=effect_scene,
|
effect_scene=effect_scene,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -1253,11 +1255,11 @@ class KlingLipSyncBase(KlingNodeBase):
|
|||||||
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
|
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> KlingLipSyncResponse:
|
) -> KlingLipSyncResponse:
|
||||||
"""Polls the Kling API endpoint until the task reaches a terminal state."""
|
"""Polls the Kling API endpoint until the task reaches a terminal state."""
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_LIP_SYNC}/{task_id}",
|
path=f"{PATH_LIP_SYNC}/{task_id}",
|
||||||
@ -1270,7 +1272,7 @@ class KlingLipSyncBase(KlingNodeBase):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
audio: Optional[AudioInput] = None,
|
audio: Optional[AudioInput] = None,
|
||||||
@ -1287,12 +1289,12 @@ class KlingLipSyncBase(KlingNodeBase):
|
|||||||
self.validate_lip_sync_video(video)
|
self.validate_lip_sync_video(video)
|
||||||
|
|
||||||
# Upload video to Comfy API and get download URL
|
# Upload video to Comfy API and get download URL
|
||||||
video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
|
video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs)
|
||||||
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
|
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
|
||||||
|
|
||||||
# Upload the audio file to Comfy API and get download URL
|
# Upload the audio file to Comfy API and get download URL
|
||||||
if audio:
|
if audio:
|
||||||
audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
|
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
|
||||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||||
else:
|
else:
|
||||||
audio_url = None
|
audio_url = None
|
||||||
@ -1319,17 +1321,17 @@ class KlingLipSyncBase(KlingNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.data.task_id
|
task_id = task_creation_response.data.task_id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
validate_video_result_response(final_response)
|
validate_video_result_response(final_response)
|
||||||
|
|
||||||
video = get_video_from_response(final_response)
|
video = get_video_from_response(final_response)
|
||||||
return video_result_to_node_output(video)
|
return await video_result_to_node_output(video)
|
||||||
|
|
||||||
|
|
||||||
class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
||||||
@ -1357,7 +1359,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
audio: AudioInput,
|
audio: AudioInput,
|
||||||
@ -1365,7 +1367,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
|
|||||||
unique_id: Optional[str] = None,
|
unique_id: Optional[str] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return super().api_call(
|
return await super().api_call(
|
||||||
video=video,
|
video=video,
|
||||||
audio=audio,
|
audio=audio,
|
||||||
voice_language=voice_language,
|
voice_language=voice_language,
|
||||||
@ -1469,7 +1471,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
text: str,
|
text: str,
|
||||||
@ -1479,7 +1481,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
|
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
|
||||||
return super().api_call(
|
return await super().api_call(
|
||||||
video=video,
|
video=video,
|
||||||
text=text,
|
text=text,
|
||||||
voice_language=voice_language,
|
voice_language=voice_language,
|
||||||
@ -1533,10 +1535,10 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
|
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> KlingVirtualTryOnResponse:
|
) -> KlingVirtualTryOnResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
|
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
|
||||||
@ -1549,7 +1551,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
human_image: torch.Tensor,
|
human_image: torch.Tensor,
|
||||||
cloth_image: torch.Tensor,
|
cloth_image: torch.Tensor,
|
||||||
@ -1572,17 +1574,17 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.data.task_id
|
task_id = task_creation_response.data.task_id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
validate_image_result_response(final_response)
|
validate_image_result_response(final_response)
|
||||||
|
|
||||||
images = get_images_from_response(final_response)
|
images = get_images_from_response(final_response)
|
||||||
return (image_result_to_node_output(images),)
|
return (await image_result_to_node_output(images),)
|
||||||
|
|
||||||
|
|
||||||
class KlingImageGenerationNode(KlingImageGenerationBase):
|
class KlingImageGenerationNode(KlingImageGenerationBase):
|
||||||
@ -1655,13 +1657,13 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
|
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
auth_kwargs: Optional[dict[str, str]],
|
auth_kwargs: Optional[dict[str, str]],
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
) -> KlingImageGenerationsResponse:
|
) -> KlingImageGenerationsResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
|
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
|
||||||
@ -1674,7 +1676,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
model_name: KlingImageGenModelName,
|
model_name: KlingImageGenModelName,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -1690,7 +1692,11 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
|||||||
):
|
):
|
||||||
self.validate_prompt(prompt, negative_prompt)
|
self.validate_prompt(prompt, negative_prompt)
|
||||||
|
|
||||||
if image is not None:
|
if image is None:
|
||||||
|
image_type = None
|
||||||
|
elif model_name == KlingImageGenModelName.kling_v1:
|
||||||
|
raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
|
||||||
|
else:
|
||||||
image = tensor_to_base64_string(image)
|
image = tensor_to_base64_string(image)
|
||||||
|
|
||||||
initial_operation = SynchronousOperation(
|
initial_operation = SynchronousOperation(
|
||||||
@ -1714,17 +1720,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.data.task_id
|
task_id = task_creation_response.data.task_id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
validate_image_result_response(final_response)
|
validate_image_result_response(final_response)
|
||||||
|
|
||||||
images = get_images_from_response(final_response)
|
images = get_images_from_response(final_response)
|
||||||
return (image_result_to_node_output(images),)
|
return (await image_result_to_node_output(images),)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
@ -38,7 +38,7 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
)
|
)
|
||||||
from comfy.cmd.server import PromptServer
|
from comfy.cmd.server import PromptServer
|
||||||
|
|
||||||
import requests
|
import aiohttp
|
||||||
import torch
|
import torch
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
@ -217,7 +217,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
@ -234,19 +234,19 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
|||||||
# handle image_luma_ref
|
# handle image_luma_ref
|
||||||
api_image_ref = None
|
api_image_ref = None
|
||||||
if image_luma_ref is not None:
|
if image_luma_ref is not None:
|
||||||
api_image_ref = self._convert_luma_refs(
|
api_image_ref = await self._convert_luma_refs(
|
||||||
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
image_luma_ref, max_refs=4, auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
# handle style_luma_ref
|
# handle style_luma_ref
|
||||||
api_style_ref = None
|
api_style_ref = None
|
||||||
if style_image is not None:
|
if style_image is not None:
|
||||||
api_style_ref = self._convert_style_image(
|
api_style_ref = await self._convert_style_image(
|
||||||
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
style_image, weight=style_image_weight, auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
# handle character_ref images
|
# handle character_ref images
|
||||||
character_ref = None
|
character_ref = None
|
||||||
if character_image is not None:
|
if character_image is not None:
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
character_image, max_images=4, auth_kwargs=kwargs,
|
character_image, max_images=4, auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
character_ref = LumaCharacterRef(
|
character_ref = LumaCharacterRef(
|
||||||
@ -270,7 +270,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = operation.execute()
|
response_api: LumaGeneration = await operation.execute()
|
||||||
|
|
||||||
operation = PollingOperation(
|
operation = PollingOperation(
|
||||||
poll_endpoint=ApiEndpoint(
|
poll_endpoint=ApiEndpoint(
|
||||||
@ -286,19 +286,20 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
|||||||
node_id=unique_id,
|
node_id=unique_id,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_poll = operation.execute()
|
response_poll = await operation.execute()
|
||||||
|
|
||||||
img_response = requests.get(response_poll.assets.image)
|
async with aiohttp.ClientSession() as session:
|
||||||
img = process_image_response(img_response)
|
async with session.get(response_poll.assets.image) as img_response:
|
||||||
|
img = process_image_response(await img_response.content.read())
|
||||||
return (img,)
|
return (img,)
|
||||||
|
|
||||||
def _convert_luma_refs(
|
async def _convert_luma_refs(
|
||||||
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||||
):
|
):
|
||||||
luma_urls = []
|
luma_urls = []
|
||||||
ref_count = 0
|
ref_count = 0
|
||||||
for ref in luma_ref.refs:
|
for ref in luma_ref.refs:
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||||
)
|
)
|
||||||
luma_urls.append(download_urls[0])
|
luma_urls.append(download_urls[0])
|
||||||
@ -307,13 +308,13 @@ class LumaImageGenerationNode(ComfyNodeABC):
|
|||||||
break
|
break
|
||||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||||
|
|
||||||
async def _convert_style_image(
    self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
):
    """Convert a single style image into the API reference model.

    Wraps ``style_image`` and ``weight`` in a one-entry ``LumaReferenceChain``
    and delegates to ``_convert_luma_refs`` with ``max_refs=1``.

    Args:
        style_image: Image used as the style reference.
        weight: Weight attached to the style reference.
        auth_kwargs: Auth values forwarded unchanged to ``_convert_luma_refs``.

    Returns:
        Whatever ``_convert_luma_refs`` returns for the one-entry chain.
    """
    style_chain = LumaReferenceChain(
        first_ref=LumaReference(image=style_image, weight=weight)
    )
    # _convert_luma_refs is a coroutine function in this revision, so it has to
    # be awaited; returning it un-awaited would hand callers a coroutine object.
    return await self._convert_luma_refs(style_chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||||
|
|
||||||
|
|
||||||
class LumaImageModifyNode(ComfyNodeABC):
|
class LumaImageModifyNode(ComfyNodeABC):
|
||||||
@ -370,7 +371,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
@ -381,7 +382,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# first, upload image
|
# first, upload image
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
image, max_images=1, auth_kwargs=kwargs,
|
image, max_images=1, auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
image_url = download_urls[0]
|
image_url = download_urls[0]
|
||||||
@ -402,7 +403,7 @@ class LumaImageModifyNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = operation.execute()
|
response_api: LumaGeneration = await operation.execute()
|
||||||
|
|
||||||
operation = PollingOperation(
|
operation = PollingOperation(
|
||||||
poll_endpoint=ApiEndpoint(
|
poll_endpoint=ApiEndpoint(
|
||||||
@ -418,10 +419,11 @@ class LumaImageModifyNode(ComfyNodeABC):
|
|||||||
node_id=unique_id,
|
node_id=unique_id,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_poll = operation.execute()
|
response_poll = await operation.execute()
|
||||||
|
|
||||||
img_response = requests.get(response_poll.assets.image)
|
async with aiohttp.ClientSession() as session:
|
||||||
img = process_image_response(img_response)
|
async with session.get(response_poll.assets.image) as img_response:
|
||||||
|
img = process_image_response(await img_response.content.read())
|
||||||
return (img,)
|
return (img,)
|
||||||
|
|
||||||
|
|
||||||
@ -494,7 +496,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
@ -529,7 +531,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = operation.execute()
|
response_api: LumaGeneration = await operation.execute()
|
||||||
|
|
||||||
if unique_id:
|
if unique_id:
|
||||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||||
@ -549,10 +551,11 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
|
|||||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_poll = operation.execute()
|
response_poll = await operation.execute()
|
||||||
|
|
||||||
vid_response = requests.get(response_poll.assets.video)
|
async with aiohttp.ClientSession() as session:
|
||||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
async with session.get(response_poll.assets.video) as vid_response:
|
||||||
|
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||||
|
|
||||||
|
|
||||||
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
||||||
@ -626,7 +629,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: str,
|
model: str,
|
||||||
@ -644,7 +647,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
"At least one of first_image and last_image requires an input."
|
"At least one of first_image and last_image requires an input."
|
||||||
)
|
)
|
||||||
keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
|
||||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||||
|
|
||||||
@ -667,7 +670,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = operation.execute()
|
response_api: LumaGeneration = await operation.execute()
|
||||||
|
|
||||||
if unique_id:
|
if unique_id:
|
||||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
|
||||||
@ -687,12 +690,13 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
|||||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_poll = operation.execute()
|
response_poll = await operation.execute()
|
||||||
|
|
||||||
vid_response = requests.get(response_poll.assets.video)
|
async with aiohttp.ClientSession() as session:
|
||||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
async with session.get(response_poll.assets.video) as vid_response:
|
||||||
|
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||||
|
|
||||||
def _convert_to_keyframes(
|
async def _convert_to_keyframes(
|
||||||
self,
|
self,
|
||||||
first_image: torch.Tensor = None,
|
first_image: torch.Tensor = None,
|
||||||
last_image: torch.Tensor = None,
|
last_image: torch.Tensor = None,
|
||||||
@ -703,12 +707,12 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
|
|||||||
frame0 = None
|
frame0 = None
|
||||||
frame1 = None
|
frame1 = None
|
||||||
if first_image is not None:
|
if first_image is not None:
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||||
)
|
)
|
||||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||||
if last_image is not None:
|
if last_image is not None:
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||||
)
|
)
|
||||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from inspect import cleandoc
|
||||||
from typing import Union
|
from typing import Union
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
@ -10,7 +11,7 @@ from comfy_api_nodes.apis import (
|
|||||||
MinimaxFileRetrieveResponse,
|
MinimaxFileRetrieveResponse,
|
||||||
MinimaxTaskResultResponse,
|
MinimaxTaskResultResponse,
|
||||||
SubjectReferenceItem,
|
SubjectReferenceItem,
|
||||||
Model
|
MiniMaxModel
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -84,9 +85,8 @@ class MinimaxTextToVideoNode:
|
|||||||
FUNCTION = "generate_video"
|
FUNCTION = "generate_video"
|
||||||
CATEGORY = "api node/video/MiniMax"
|
CATEGORY = "api node/video/MiniMax"
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
def generate_video(
|
async def generate_video(
|
||||||
self,
|
self,
|
||||||
prompt_text,
|
prompt_text,
|
||||||
seed=0,
|
seed=0,
|
||||||
@ -104,12 +104,12 @@ class MinimaxTextToVideoNode:
|
|||||||
# upload image, if passed in
|
# upload image, if passed in
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
|
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]
|
||||||
|
|
||||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||||
subject_reference = None
|
subject_reference = None
|
||||||
if subject is not None:
|
if subject is not None:
|
||||||
subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
|
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
|
||||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||||
|
|
||||||
|
|
||||||
@ -121,7 +121,7 @@ class MinimaxTextToVideoNode:
|
|||||||
response_model=MinimaxVideoGenerationResponse,
|
response_model=MinimaxVideoGenerationResponse,
|
||||||
),
|
),
|
||||||
request=MinimaxVideoGenerationRequest(
|
request=MinimaxVideoGenerationRequest(
|
||||||
model=Model(model),
|
model=MiniMaxModel(model),
|
||||||
prompt=prompt_text,
|
prompt=prompt_text,
|
||||||
callback_url=None,
|
callback_url=None,
|
||||||
first_frame_image=image_url,
|
first_frame_image=image_url,
|
||||||
@ -130,7 +130,7 @@ class MinimaxTextToVideoNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response = video_generate_operation.execute()
|
response = await video_generate_operation.execute()
|
||||||
|
|
||||||
task_id = response.task_id
|
task_id = response.task_id
|
||||||
if not task_id:
|
if not task_id:
|
||||||
@ -151,7 +151,7 @@ class MinimaxTextToVideoNode:
|
|||||||
node_id=unique_id,
|
node_id=unique_id,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
task_result = video_generate_operation.execute()
|
task_result = await video_generate_operation.execute()
|
||||||
|
|
||||||
file_id = task_result.file_id
|
file_id = task_result.file_id
|
||||||
if file_id is None:
|
if file_id is None:
|
||||||
@ -167,7 +167,7 @@ class MinimaxTextToVideoNode:
|
|||||||
request=EmptyRequest(),
|
request=EmptyRequest(),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
file_result = file_retrieve_operation.execute()
|
file_result = await file_retrieve_operation.execute()
|
||||||
|
|
||||||
file_url = file_result.file.download_url
|
file_url = file_result.file.download_url
|
||||||
if file_url is None:
|
if file_url is None:
|
||||||
@ -182,7 +182,7 @@ class MinimaxTextToVideoNode:
|
|||||||
message = f"Result URL: {file_url}"
|
message = f"Result URL: {file_url}"
|
||||||
PromptServer.instance.send_progress_text(message, unique_id)
|
PromptServer.instance.send_progress_text(message, unique_id)
|
||||||
|
|
||||||
video_io = download_url_to_bytesio(file_url)
|
video_io = await download_url_to_bytesio(file_url)
|
||||||
if video_io is None:
|
if video_io is None:
|
||||||
error_msg = f"Failed to download video from {file_url}"
|
error_msg = f"Failed to download video from {file_url}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
@ -251,7 +251,6 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
|
|||||||
FUNCTION = "generate_video"
|
FUNCTION = "generate_video"
|
||||||
CATEGORY = "api node/video/MiniMax"
|
CATEGORY = "api node/video/MiniMax"
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
|
|
||||||
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
||||||
@ -313,7 +312,181 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
|
|||||||
FUNCTION = "generate_video"
|
FUNCTION = "generate_video"
|
||||||
CATEGORY = "api node/video/MiniMax"
|
CATEGORY = "api node/video/MiniMax"
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
|
class MinimaxHailuoVideoNode:
    """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""

    # NOTE: DESCRIPTION below is derived from the class docstring above, so the
    # docstring text is user-facing; extra documentation lives in comments.

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input schema (required, optional and hidden inputs)."""
        return {
            "required": {
                "prompt_text": (
                    "STRING",
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Text prompt to guide the video generation.",
                    },
                ),
            },
            "optional": {
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 0xFFFFFFFFFFFFFFFF,
                        "control_after_generate": True,
                        "tooltip": "The random seed used for creating the noise.",
                    },
                ),
                "first_frame_image": (
                    IO.IMAGE,
                    {
                        "tooltip": "Optional image to use as the first frame to generate a video."
                    },
                ),
                "prompt_optimizer": (
                    IO.BOOLEAN,
                    {
                        "tooltip": "Optimize prompt to improve generation quality when needed.",
                        "default": True,
                    },
                ),
                "duration": (
                    IO.COMBO,
                    {
                        "tooltip": "The length of the output video in seconds.",
                        "default": 6,
                        "options": [6, 10],
                    },
                ),
                "resolution": (
                    IO.COMBO,
                    {
                        "tooltip": "The dimensions of the video display. "
                                   "1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
                        "default": "768P",
                        "options": ["768P", "1080P"],
                    },
                ),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = ("VIDEO",)
    DESCRIPTION = cleandoc(__doc__ or "")
    FUNCTION = "generate_video"
    CATEGORY = "api node/video/MiniMax"
    API_NODE = True

    async def generate_video(
        self,
        prompt_text,
        seed=0,
        first_frame_image: torch.Tensor=None, # used for ImageToVideo
        prompt_optimizer=True,
        duration=6,
        resolution="768P",
        model="MiniMax-Hailuo-02",
        unique_id: Union[str, None]=None,
        **kwargs,
    ):
        """Create a MiniMax video task, wait for it, and return the downloaded video.

        Steps: validate inputs, optionally upload the first frame, submit the
        generation request, poll the task, retrieve the file record, then
        download the video.

        Args:
            prompt_text: Text prompt; validated only when no first frame is given.
            seed: Accepted from the node inputs; not forwarded in the request
                built here.
            first_frame_image: Optional image uploaded and used as first frame.
            prompt_optimizer: Forwarded to the generation request.
            duration: Video length in seconds, forwarded to the request.
            resolution: Resolution label, forwarded to the request unchanged.
            model: Model identifier passed to ``MiniMaxModel``.
            unique_id: Node id used for progress text; skipped when falsy.
            **kwargs: Passed as ``auth_kwargs`` to every API operation.

        Returns:
            A one-element tuple holding a ``VideoFromFile``.

        Raises:
            Exception: On an unsupported resolution/duration combination, a
                missing task id, file id or download URL, or a failed download.
        """
        if first_frame_image is None:
            validate_string(prompt_text, field_name="prompt_text")

        resolution_label = resolution.upper()
        if model == "MiniMax-Hailuo-02" and resolution_label == "1080P" and duration != 6:
            raise Exception(
                "When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds."
            )

        # upload image, if passed in
        image_url = None
        if first_frame_image is not None:
            image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]

        video_generate_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/minimax/video_generation",
                method=HttpMethod.POST,
                request_model=MinimaxVideoGenerationRequest,
                response_model=MinimaxVideoGenerationResponse,
            ),
            request=MinimaxVideoGenerationRequest(
                model=MiniMaxModel(model),
                prompt=prompt_text,
                callback_url=None,
                first_frame_image=image_url,
                prompt_optimizer=prompt_optimizer,
                duration=duration,
                resolution=resolution,
            ),
            auth_kwargs=kwargs,
        )
        response = await video_generate_operation.execute()

        task_id = response.task_id
        if not task_id:
            raise Exception(f"MiniMax generation failed: {response.base_resp}")

        # Use the same case-insensitive comparison as the validation above so a
        # lower-case "768p" gets the 768P estimate too.
        average_duration = 120 if resolution_label == "768P" else 240
        video_generate_operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path="/proxy/minimax/query/video_generation",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=MinimaxTaskResultResponse,
                query_params={"task_id": task_id},
            ),
            completed_statuses=["Success"],
            failed_statuses=["Fail"],
            status_extractor=lambda x: x.status.value,
            estimated_duration=average_duration,
            node_id=unique_id,
            auth_kwargs=kwargs,
        )
        task_result = await video_generate_operation.execute()

        file_id = task_result.file_id
        if file_id is None:
            raise Exception("Request was not successful. Missing file ID.")
        file_retrieve_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/minimax/files/retrieve",
                method=HttpMethod.GET,
                request_model=EmptyRequest,
                response_model=MinimaxFileRetrieveResponse,
                query_params={"file_id": int(file_id)},
            ),
            request=EmptyRequest(),
            auth_kwargs=kwargs,
        )
        file_result = await file_retrieve_operation.execute()

        file_url = file_result.file.download_url
        if file_url is None:
            raise Exception(
                f"No video was found in the response. Full response: {file_result.model_dump()}"
            )
        logging.info("Generated video URL: %s", file_url)
        if unique_id:
            # NOTE(review): a declared-but-empty attribute still satisfies
            # hasattr(), which would print "Backup URL: None"; only mention the
            # backup when it actually has a value. Confirm against the model.
            backup_url = getattr(file_result.file, "backup_download_url", None)
            if backup_url:
                message = f"Result URL: {file_url}\nBackup URL: {backup_url}"
            else:
                message = f"Result URL: {file_url}"
            PromptServer.instance.send_progress_text(message, unique_id)

        video_io = await download_url_to_bytesio(file_url)
        if video_io is None:
            error_msg = f"Failed to download video from {file_url}"
            logging.error(error_msg)
            raise Exception(error_msg)
        return (VideoFromFile(video_io),)
|
||||||
|
|
||||||
|
|
||||||
# A dictionary that contains all nodes you want to export with their names
|
# A dictionary that contains all nodes you want to export with their names
|
||||||
@ -322,6 +495,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
|
||||||
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
|
||||||
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
|
||||||
|
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
||||||
@ -329,4 +503,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
"MinimaxTextToVideoNode": "MiniMax Text to Video",
|
||||||
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
"MinimaxImageToVideoNode": "MiniMax Image to Video",
|
||||||
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
|
||||||
|
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
from typing import Any, Callable, Optional, TypeVar
|
||||||
import random
|
|
||||||
import torch
|
import torch
|
||||||
from comfy_api_nodes.util.validation_utils import (
|
from comfy_api_nodes.util.validation_utils import (
|
||||||
get_image_dimensions,
|
get_image_dimensions,
|
||||||
@ -95,14 +94,14 @@ def get_video_url_from_response(response) -> Optional[str]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def poll_until_finished(
|
async def poll_until_finished(
|
||||||
auth_kwargs: dict[str, str],
|
auth_kwargs: dict[str, str],
|
||||||
api_endpoint: ApiEndpoint[Any, R],
|
api_endpoint: ApiEndpoint[Any, R],
|
||||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
) -> R:
|
) -> R:
|
||||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
||||||
return PollingOperation(
|
return await PollingOperation(
|
||||||
poll_endpoint=api_endpoint,
|
poll_endpoint=api_endpoint,
|
||||||
completed_statuses=[
|
completed_statuses=[
|
||||||
"completed",
|
"completed",
|
||||||
@ -208,20 +207,29 @@ def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
|||||||
def _validate_video_dimensions(width: int, height: int) -> None:
|
def _validate_video_dimensions(width: int, height: int) -> None:
|
||||||
"""Validates video dimensions meet Moonvalley V2V requirements."""
|
"""Validates video dimensions meet Moonvalley V2V requirements."""
|
||||||
supported_resolutions = {
|
supported_resolutions = {
|
||||||
(1920, 1080), (1080, 1920), (1152, 1152),
|
(1920, 1080),
|
||||||
(1536, 1152), (1152, 1536)
|
(1080, 1920),
|
||||||
|
(1152, 1152),
|
||||||
|
(1536, 1152),
|
||||||
|
(1152, 1536),
|
||||||
}
|
}
|
||||||
|
|
||||||
if (width, height) not in supported_resolutions:
|
if (width, height) not in supported_resolutions:
|
||||||
supported_list = ', '.join([f'{w}x{h}' for w, h in sorted(supported_resolutions)])
|
supported_list = ", ".join(
|
||||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
||||||
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _validate_container_format(video: VideoInput) -> None:
    """Validates video container format is MP4.

    Args:
        video: Input whose ``get_container_format()`` result is checked.

    Raises:
        ValueError: If the reported format is neither ``"mp4"`` nor the
            combined ``"mov,mp4,m4a,3gp,3g2,mj2"`` string.
    """
    container_format = video.get_container_format()
    # Both the bare "mp4" name and the combined multi-name string are accepted.
    accepted_formats = ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]
    if container_format not in accepted_formats:
        raise ValueError(
            f"Only MP4 container format supported. Got: {container_format}"
        )
|
||||||
|
|
||||||
|
|
||||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||||
@ -244,7 +252,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
|||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||||
"""
|
"""
|
||||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||||
@ -302,7 +309,9 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|||||||
# Calculate target frame count that's divisible by 16
|
# Calculate target frame count that's divisible by 16
|
||||||
fps = input_container.streams.video[0].average_rate
|
fps = input_container.streams.video[0].average_rate
|
||||||
estimated_frames = int(duration_sec * fps)
|
estimated_frames = int(duration_sec * fps)
|
||||||
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
target_frames = (
|
||||||
|
estimated_frames // 16
|
||||||
|
) * 16 # Round down to nearest multiple of 16
|
||||||
|
|
||||||
if target_frames == 0:
|
if target_frames == 0:
|
||||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||||
@ -394,10 +403,10 @@ class BaseMoonvalleyVideoNode:
|
|||||||
else:
|
else:
|
||||||
return control_map["Motion Transfer"]
|
return control_map["Motion Transfer"]
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> MoonvalleyPromptResponse:
|
) -> MoonvalleyPromptResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||||
@ -424,7 +433,7 @@ class BaseMoonvalleyVideoNode:
|
|||||||
MoonvalleyTextToVideoInferenceParams,
|
MoonvalleyTextToVideoInferenceParams,
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts",
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
),
|
),
|
||||||
"resolution": (
|
"resolution": (
|
||||||
IO.COMBO,
|
IO.COMBO,
|
||||||
@ -441,12 +450,11 @@ class BaseMoonvalleyVideoNode:
|
|||||||
"tooltip": "Resolution of the output video",
|
"tooltip": "Resolution of the output video",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
# "length": (IO.COMBO,{"options":['5s','10s'], "default": '5s'}),
|
|
||||||
"prompt_adherence": model_field_to_node_input(
|
"prompt_adherence": model_field_to_node_input(
|
||||||
IO.FLOAT,
|
IO.FLOAT,
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
MoonvalleyTextToVideoInferenceParams,
|
||||||
"guidance_scale",
|
"guidance_scale",
|
||||||
default=7.0,
|
default=10.0,
|
||||||
step=1,
|
step=1,
|
||||||
min=1,
|
min=1,
|
||||||
max=20,
|
max=20,
|
||||||
@ -455,13 +463,12 @@ class BaseMoonvalleyVideoNode:
|
|||||||
IO.INT,
|
IO.INT,
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
MoonvalleyTextToVideoInferenceParams,
|
||||||
"seed",
|
"seed",
|
||||||
default=random.randint(0, 2**32 - 1),
|
default=9,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
display="number",
|
display="number",
|
||||||
tooltip="Random seed value",
|
tooltip="Random seed value",
|
||||||
control_after_generate=True,
|
|
||||||
),
|
),
|
||||||
"steps": model_field_to_node_input(
|
"steps": model_field_to_node_input(
|
||||||
IO.INT,
|
IO.INT,
|
||||||
@ -507,7 +514,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
RETURN_NAMES = ("video",)
|
RETURN_NAMES = ("video",)
|
||||||
DESCRIPTION = "Moonvalley Marey Image to Video Node"
|
DESCRIPTION = "Moonvalley Marey Image to Video Node"
|
||||||
|
|
||||||
def generate(
|
async def generate(
|
||||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||||
):
|
):
|
||||||
image = kwargs.get("image", None)
|
image = kwargs.get("image", None)
|
||||||
@ -532,8 +539,10 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||||
mime_type = "image/png"
|
mime_type = "image/png"
|
||||||
|
|
||||||
image_url = upload_images_to_comfyapi(
|
image_url = (
|
||||||
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
await upload_images_to_comfyapi(
|
||||||
|
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
|
||||||
|
)
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
request = MoonvalleyTextToVideoRequest(
|
request = MoonvalleyTextToVideoRequest(
|
||||||
@ -549,14 +558,14 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
request=request,
|
request=request,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
task_id = task_creation_response.id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
video = download_url_to_video_output(final_response.output_url)
|
video = await download_url_to_video_output(final_response.output_url)
|
||||||
return (video,)
|
return (video,)
|
||||||
|
|
||||||
|
|
||||||
@ -570,17 +579,39 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"prompt": model_field_to_node_input(
|
"prompt": model_field_to_node_input(
|
||||||
IO.STRING, MoonvalleyVideoToVideoRequest, "prompt_text",
|
IO.STRING,
|
||||||
multiline=True
|
MoonvalleyVideoToVideoRequest,
|
||||||
|
"prompt_text",
|
||||||
|
multiline=True,
|
||||||
),
|
),
|
||||||
"negative_prompt": model_field_to_node_input(
|
"negative_prompt": model_field_to_node_input(
|
||||||
IO.STRING,
|
IO.STRING,
|
||||||
MoonvalleyVideoToVideoInferenceParams,
|
MoonvalleyVideoToVideoInferenceParams,
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="low-poly, flat shader, bad rigging, stiff animation, uncanny eyes, low-quality textures, looping glitch, cheap effect, overbloom, bloom spam, default lighting, game asset, stiff face, ugly specular, AI artifacts"
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
|
),
|
||||||
|
"seed": model_field_to_node_input(
|
||||||
|
IO.INT,
|
||||||
|
MoonvalleyVideoToVideoInferenceParams,
|
||||||
|
"seed",
|
||||||
|
default=9,
|
||||||
|
min=0,
|
||||||
|
max=4294967295,
|
||||||
|
step=1,
|
||||||
|
display="number",
|
||||||
|
tooltip="Random seed value",
|
||||||
|
control_after_generate=False,
|
||||||
|
),
|
||||||
|
"prompt_adherence": model_field_to_node_input(
|
||||||
|
IO.FLOAT,
|
||||||
|
MoonvalleyVideoToVideoInferenceParams,
|
||||||
|
"guidance_scale",
|
||||||
|
default=10.0,
|
||||||
|
step=1,
|
||||||
|
min=1,
|
||||||
|
max=20,
|
||||||
),
|
),
|
||||||
"seed": model_field_to_node_input(IO.INT,MoonvalleyVideoToVideoInferenceParams, "seed", default=random.randint(0, 2**32 - 1), min=0, max=4294967295, step=1, display="number", tooltip="Random seed value", control_after_generate=True),
|
|
||||||
},
|
},
|
||||||
"hidden": {
|
"hidden": {
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||||
@ -588,7 +619,14 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
"unique_id": "UNIQUE_ID",
|
"unique_id": "UNIQUE_ID",
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"video": (IO.VIDEO, {"default": "", "multiline": False, "tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported."}),
|
"video": (
|
||||||
|
IO.VIDEO,
|
||||||
|
{
|
||||||
|
"default": "",
|
||||||
|
"multiline": False,
|
||||||
|
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||||
|
},
|
||||||
|
),
|
||||||
"control_type": (
|
"control_type": (
|
||||||
["Motion Transfer", "Pose Transfer"],
|
["Motion Transfer", "Pose Transfer"],
|
||||||
{"default": "Motion Transfer"},
|
{"default": "Motion Transfer"},
|
||||||
@ -602,17 +640,24 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
"max": 100,
|
"max": 100,
|
||||||
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
"tooltip": "Only used if control_type is 'Motion Transfer'",
|
||||||
},
|
},
|
||||||
)
|
),
|
||||||
}
|
"image": model_field_to_node_input(
|
||||||
|
IO.IMAGE,
|
||||||
|
MoonvalleyTextToVideoRequest,
|
||||||
|
"image_url",
|
||||||
|
tooltip="The reference image used to generate the video",
|
||||||
|
),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = ("VIDEO",)
|
RETURN_TYPES = ("VIDEO",)
|
||||||
RETURN_NAMES = ("video",)
|
RETURN_NAMES = ("video",)
|
||||||
|
|
||||||
def generate(
|
async def generate(
|
||||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||||
):
|
):
|
||||||
video = kwargs.get("video")
|
video = kwargs.get("video")
|
||||||
|
image = kwargs.get("image", None)
|
||||||
|
|
||||||
if not video:
|
if not video:
|
||||||
raise MoonvalleyApiError("video is required")
|
raise MoonvalleyApiError("video is required")
|
||||||
@ -620,8 +665,16 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
video_url = ""
|
video_url = ""
|
||||||
if video:
|
if video:
|
||||||
validated_video = validate_video_to_video_input(video)
|
validated_video = validate_video_to_video_input(video)
|
||||||
video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs)
|
video_url = await upload_video_to_comfyapi(
|
||||||
|
validated_video, auth_kwargs=kwargs
|
||||||
|
)
|
||||||
|
mime_type = "image/png"
|
||||||
|
|
||||||
|
if not image is None:
|
||||||
|
validate_input_image(image, with_frame_conditioning=True)
|
||||||
|
image_url = await upload_images_to_comfyapi(
|
||||||
|
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
|
||||||
|
)
|
||||||
control_type = kwargs.get("control_type")
|
control_type = kwargs.get("control_type")
|
||||||
motion_intensity = kwargs.get("motion_intensity")
|
motion_intensity = kwargs.get("motion_intensity")
|
||||||
|
|
||||||
@ -631,12 +684,12 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
# Only include motion_intensity for Motion Transfer
|
# Only include motion_intensity for Motion Transfer
|
||||||
control_params = {}
|
control_params = {}
|
||||||
if control_type == "Motion Transfer" and motion_intensity is not None:
|
if control_type == "Motion Transfer" and motion_intensity is not None:
|
||||||
control_params['motion_intensity'] = motion_intensity
|
control_params["motion_intensity"] = motion_intensity
|
||||||
|
|
||||||
inference_params=MoonvalleyVideoToVideoInferenceParams(
|
inference_params = MoonvalleyVideoToVideoInferenceParams(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=kwargs.get("seed"),
|
seed=kwargs.get("seed"),
|
||||||
control_params=control_params
|
control_params=control_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
control = self.parseControlParameter(control_type)
|
control = self.parseControlParameter(control_type)
|
||||||
@ -647,6 +700,7 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
prompt_text=prompt,
|
prompt_text=prompt,
|
||||||
inference_params=inference_params,
|
inference_params=inference_params,
|
||||||
)
|
)
|
||||||
|
request.image_url = image_url if not image is None else None
|
||||||
|
|
||||||
initial_operation = SynchronousOperation(
|
initial_operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
@ -658,15 +712,15 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
request=request,
|
request=request,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
task_id = task_creation_response.id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
|
|
||||||
video = download_url_to_video_output(final_response.output_url)
|
video = await download_url_to_video_output(final_response.output_url)
|
||||||
|
|
||||||
return (video,)
|
return (video,)
|
||||||
|
|
||||||
@ -688,21 +742,21 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
del input_types["optional"][param]
|
del input_types["optional"][param]
|
||||||
return input_types
|
return input_types
|
||||||
|
|
||||||
def generate(
|
async def generate(
|
||||||
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
|
||||||
):
|
):
|
||||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
|
||||||
|
|
||||||
inference_params=MoonvalleyTextToVideoInferenceParams(
|
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
steps=kwargs.get("steps"),
|
steps=kwargs.get("steps"),
|
||||||
seed=kwargs.get("seed"),
|
seed=kwargs.get("seed"),
|
||||||
guidance_scale=kwargs.get("prompt_adherence"),
|
guidance_scale=kwargs.get("prompt_adherence"),
|
||||||
num_frames=128,
|
num_frames=128,
|
||||||
width=width_height.get("width"),
|
width=width_height.get("width"),
|
||||||
height=width_height.get("height"),
|
height=width_height.get("height"),
|
||||||
)
|
)
|
||||||
request = MoonvalleyTextToVideoRequest(
|
request = MoonvalleyTextToVideoRequest(
|
||||||
prompt_text=prompt, inference_params=inference_params
|
prompt_text=prompt, inference_params=inference_params
|
||||||
)
|
)
|
||||||
@ -717,15 +771,15 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
|
|||||||
request=request,
|
request=request,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
task_creation_response = initial_operation.execute()
|
task_creation_response = await initial_operation.execute()
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
task_id = task_creation_response.id
|
||||||
|
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
|
|
||||||
video = download_url_to_video_output(final_response.output_url)
|
video = await download_url_to_video_output(final_response.output_url)
|
||||||
return (video,)
|
return (video,)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -80,6 +80,9 @@ class SupportedOpenAIModel(str, Enum):
|
|||||||
gpt_4_1 = "gpt-4.1"
|
gpt_4_1 = "gpt-4.1"
|
||||||
gpt_4_1_mini = "gpt-4.1-mini"
|
gpt_4_1_mini = "gpt-4.1-mini"
|
||||||
gpt_4_1_nano = "gpt-4.1-nano"
|
gpt_4_1_nano = "gpt-4.1-nano"
|
||||||
|
gpt_5 = "gpt-5"
|
||||||
|
gpt_5_mini = "gpt-5-mini"
|
||||||
|
gpt_5_nano = "gpt-5-nano"
|
||||||
|
|
||||||
|
|
||||||
class OpenAIDalle2(ComfyNodeABC):
|
class OpenAIDalle2(ComfyNodeABC):
|
||||||
@ -163,7 +166,7 @@ class OpenAIDalle2(ComfyNodeABC):
|
|||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
seed=0,
|
seed=0,
|
||||||
@ -233,9 +236,9 @@ class OpenAIDalle2(ComfyNodeABC):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = operation.execute()
|
response = await operation.execute()
|
||||||
|
|
||||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||||
return (img_tensor,)
|
return (img_tensor,)
|
||||||
|
|
||||||
|
|
||||||
@ -311,7 +314,7 @@ class OpenAIDalle3(ComfyNodeABC):
|
|||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
seed=0,
|
seed=0,
|
||||||
@ -343,9 +346,9 @@ class OpenAIDalle3(ComfyNodeABC):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = operation.execute()
|
response = await operation.execute()
|
||||||
|
|
||||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||||
return (img_tensor,)
|
return (img_tensor,)
|
||||||
|
|
||||||
|
|
||||||
@ -446,7 +449,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
DESCRIPTION = cleandoc(__doc__ or "")
|
DESCRIPTION = cleandoc(__doc__ or "")
|
||||||
API_NODE = True
|
API_NODE = True
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
seed=0,
|
seed=0,
|
||||||
@ -464,8 +467,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
path = "/proxy/openai/images/generations"
|
path = "/proxy/openai/images/generations"
|
||||||
content_type = "application/json"
|
content_type = "application/json"
|
||||||
request_class = OpenAIImageGenerationRequest
|
request_class = OpenAIImageGenerationRequest
|
||||||
img_binaries = []
|
|
||||||
mask_binary = None
|
|
||||||
files = []
|
files = []
|
||||||
|
|
||||||
if image is not None:
|
if image is not None:
|
||||||
@ -484,14 +485,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
img_byte_arr = io.BytesIO()
|
img_byte_arr = io.BytesIO()
|
||||||
img.save(img_byte_arr, format="PNG")
|
img.save(img_byte_arr, format="PNG")
|
||||||
img_byte_arr.seek(0)
|
img_byte_arr.seek(0)
|
||||||
img_binary = img_byte_arr
|
|
||||||
img_binary.name = f"image_{i}.png"
|
|
||||||
|
|
||||||
img_binaries.append(img_binary)
|
|
||||||
if batch_size == 1:
|
if batch_size == 1:
|
||||||
files.append(("image", img_binary))
|
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||||
else:
|
else:
|
||||||
files.append(("image[]", img_binary))
|
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
if image is None:
|
if image is None:
|
||||||
@ -511,9 +509,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
mask_img_byte_arr = io.BytesIO()
|
mask_img_byte_arr = io.BytesIO()
|
||||||
mask_img.save(mask_img_byte_arr, format="PNG")
|
mask_img.save(mask_img_byte_arr, format="PNG")
|
||||||
mask_img_byte_arr.seek(0)
|
mask_img_byte_arr.seek(0)
|
||||||
mask_binary = mask_img_byte_arr
|
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
|
||||||
mask_binary.name = "mask.png"
|
|
||||||
files.append(("mask", mask_binary))
|
|
||||||
|
|
||||||
# Build the operation
|
# Build the operation
|
||||||
operation = SynchronousOperation(
|
operation = SynchronousOperation(
|
||||||
@ -537,9 +533,9 @@ class OpenAIGPTImage1(ComfyNodeABC):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = operation.execute()
|
response = await operation.execute()
|
||||||
|
|
||||||
img_tensor = validate_and_cast_response(response, node_id=unique_id)
|
img_tensor = await validate_and_cast_response(response, node_id=unique_id)
|
||||||
return (img_tensor,)
|
return (img_tensor,)
|
||||||
|
|
||||||
|
|
||||||
@ -623,7 +619,7 @@ class OpenAIChatNode(OpenAITextNode):
|
|||||||
|
|
||||||
DESCRIPTION = "Generate text responses from an OpenAI model."
|
DESCRIPTION = "Generate text responses from an OpenAI model."
|
||||||
|
|
||||||
def get_result_response(
|
async def get_result_response(
|
||||||
self,
|
self,
|
||||||
response_id: str,
|
response_id: str,
|
||||||
include: Optional[list[Includable]] = None,
|
include: Optional[list[Includable]] = None,
|
||||||
@ -639,7 +635,7 @@ class OpenAIChatNode(OpenAITextNode):
|
|||||||
creation above for more information.
|
creation above for more information.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return PollingOperation(
|
return await PollingOperation(
|
||||||
poll_endpoint=ApiEndpoint(
|
poll_endpoint=ApiEndpoint(
|
||||||
path=f"{RESPONSES_ENDPOINT}/{response_id}",
|
path=f"{RESPONSES_ENDPOINT}/{response_id}",
|
||||||
method=HttpMethod.GET,
|
method=HttpMethod.GET,
|
||||||
@ -784,7 +780,7 @@ class OpenAIChatNode(OpenAITextNode):
|
|||||||
|
|
||||||
self.history[session_id] = new_history
|
self.history[session_id] = new_history
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
persist_context: bool,
|
persist_context: bool,
|
||||||
@ -815,7 +811,7 @@ class OpenAIChatNode(OpenAITextNode):
|
|||||||
previous_response_id = None
|
previous_response_id = None
|
||||||
|
|
||||||
# Create response
|
# Create response
|
||||||
create_response = SynchronousOperation(
|
create_response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path=RESPONSES_ENDPOINT,
|
path=RESPONSES_ENDPOINT,
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -848,7 +844,7 @@ class OpenAIChatNode(OpenAITextNode):
|
|||||||
response_id = create_response.id
|
response_id = create_response.id
|
||||||
|
|
||||||
# Get result output
|
# Get result output
|
||||||
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
|
result_response = await self.get_result_response(response_id, auth_kwargs=kwargs)
|
||||||
output_text = self.parse_output_text_from_response(result_response)
|
output_text = self.parse_output_text_from_response(result_response)
|
||||||
|
|
||||||
# Update history
|
# Update history
|
||||||
@ -1002,7 +998,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"OpenAIDalle2": "OpenAI DALL·E 2",
|
"OpenAIDalle2": "OpenAI DALL·E 2",
|
||||||
"OpenAIDalle3": "OpenAI DALL·E 3",
|
"OpenAIDalle3": "OpenAI DALL·E 3",
|
||||||
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
"OpenAIGPTImage1": "OpenAI GPT Image 1",
|
||||||
"OpenAIChatNode": "OpenAI Chat",
|
"OpenAIChatNode": "OpenAI ChatGPT",
|
||||||
"OpenAIInputFiles": "OpenAI Chat Input Files",
|
"OpenAIInputFiles": "OpenAI ChatGPT Input Files",
|
||||||
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
|
"OpenAIChatConfig": "OpenAI ChatGPT Advanced Options",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -122,7 +122,7 @@ class PikaNodeBase(ComfyNodeABC):
|
|||||||
FUNCTION = "api_call"
|
FUNCTION = "api_call"
|
||||||
RETURN_TYPES = ("VIDEO",)
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
|
||||||
def poll_for_task_status(
|
async def poll_for_task_status(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
@ -152,9 +152,9 @@ class PikaNodeBase(ComfyNodeABC):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
estimated_duration=60
|
estimated_duration=60
|
||||||
)
|
)
|
||||||
return polling_operation.execute()
|
return await polling_operation.execute()
|
||||||
|
|
||||||
def execute_task(
|
async def execute_task(
|
||||||
self,
|
self,
|
||||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
auth_kwargs: Optional[dict[str, str]] = None,
|
||||||
@ -169,14 +169,14 @@ class PikaNodeBase(ComfyNodeABC):
|
|||||||
Returns:
|
Returns:
|
||||||
A tuple containing the video file as a VIDEO output.
|
A tuple containing the video file as a VIDEO output.
|
||||||
"""
|
"""
|
||||||
initial_response = initial_operation.execute()
|
initial_response = await initial_operation.execute()
|
||||||
if not is_valid_initial_response(initial_response):
|
if not is_valid_initial_response(initial_response):
|
||||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise PikaApiError(error_msg)
|
raise PikaApiError(error_msg)
|
||||||
|
|
||||||
task_id = initial_response.video_id
|
task_id = initial_response.video_id
|
||||||
final_response = self.poll_for_task_status(task_id, auth_kwargs)
|
final_response = await self.poll_for_task_status(task_id, auth_kwargs)
|
||||||
if not is_valid_video_response(final_response):
|
if not is_valid_video_response(final_response):
|
||||||
error_msg = (
|
error_msg = (
|
||||||
f"Pika task {task_id} succeeded but no video data found in response."
|
f"Pika task {task_id} succeeded but no video data found in response."
|
||||||
@ -187,7 +187,7 @@ class PikaNodeBase(ComfyNodeABC):
|
|||||||
video_url = str(final_response.url)
|
video_url = str(final_response.url)
|
||||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||||
|
|
||||||
return (download_url_to_video_output(video_url),)
|
return (await download_url_to_video_output(video_url),)
|
||||||
|
|
||||||
|
|
||||||
class PikaImageToVideoV2_2(PikaNodeBase):
|
class PikaImageToVideoV2_2(PikaNodeBase):
|
||||||
@ -212,7 +212,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
|
DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
@ -251,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
||||||
@ -281,7 +281,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
|
DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
@ -311,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
|
|||||||
content_type="application/x-www-form-urlencoded",
|
content_type="application/x-www-form-urlencoded",
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaScenesV2_2(PikaNodeBase):
|
class PikaScenesV2_2(PikaNodeBase):
|
||||||
@ -361,7 +361,7 @@ class PikaScenesV2_2(PikaNodeBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
|
DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
@ -420,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikAdditionsNode(PikaNodeBase):
|
class PikAdditionsNode(PikaNodeBase):
|
||||||
@ -462,7 +462,7 @@ class PikAdditionsNode(PikaNodeBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
|
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
@ -481,10 +481,10 @@ class PikAdditionsNode(PikaNodeBase):
|
|||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
image_bytes_io.seek(0)
|
||||||
|
|
||||||
pika_files = [
|
pika_files = {
|
||||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
("image", ("image.png", image_bytes_io, "image/png")),
|
"image": ("image.png", image_bytes_io, "image/png"),
|
||||||
]
|
}
|
||||||
|
|
||||||
# Prepare non-file data
|
# Prepare non-file data
|
||||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||||
@ -506,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaSwapsNode(PikaNodeBase):
|
class PikaSwapsNode(PikaNodeBase):
|
||||||
@ -558,7 +558,7 @@ class PikaSwapsNode(PikaNodeBase):
|
|||||||
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
|
DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
|
||||||
RETURN_TYPES = ("VIDEO",)
|
RETURN_TYPES = ("VIDEO",)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
@ -587,11 +587,11 @@ class PikaSwapsNode(PikaNodeBase):
|
|||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
image_bytes_io.seek(0)
|
||||||
|
|
||||||
pika_files = [
|
pika_files = {
|
||||||
("video", ("video.mp4", video_bytes_io, "video/mp4")),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
("image", ("image.png", image_bytes_io, "image/png")),
|
"image": ("image.png", image_bytes_io, "image/png"),
|
||||||
("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
|
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
|
||||||
]
|
}
|
||||||
|
|
||||||
# Prepare non-file data
|
# Prepare non-file data
|
||||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||||
@ -613,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaffectsNode(PikaNodeBase):
|
class PikaffectsNode(PikaNodeBase):
|
||||||
@ -664,7 +664,7 @@ class PikaffectsNode(PikaNodeBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
|
DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
pikaffect: str,
|
pikaffect: str,
|
||||||
@ -693,7 +693,7 @@ class PikaffectsNode(PikaNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||||
|
|
||||||
|
|
||||||
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
class PikaStartEndFrameNode2_2(PikaNodeBase):
|
||||||
@ -718,7 +718,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
|||||||
|
|
||||||
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
|
DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image_start: torch.Tensor,
|
image_start: torch.Tensor,
|
||||||
image_end: torch.Tensor,
|
image_end: torch.Tensor,
|
||||||
@ -732,10 +732,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
|||||||
) -> tuple[VideoFromFile]:
|
) -> tuple[VideoFromFile]:
|
||||||
|
|
||||||
pika_files = [
|
pika_files = [
|
||||||
(
|
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||||
"keyFrames",
|
|
||||||
("image_start.png", tensor_to_bytesio(image_start), "image/png"),
|
|
||||||
),
|
|
||||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -758,7 +755,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
|||||||
from comfy_api.input_impl import VideoFromFile
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import requests
|
import aiohttp
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ def get_video_url_from_response(
|
|||||||
return str(response.Resp.url)
|
return str(response.Resp.url)
|
||||||
|
|
||||||
|
|
||||||
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||||
files = {"image": tensor_to_bytesio(image)}
|
files = {"image": tensor_to_bytesio(image)}
|
||||||
operation = SynchronousOperation(
|
operation = SynchronousOperation(
|
||||||
@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
|||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth_kwargs,
|
auth_kwargs=auth_kwargs,
|
||||||
)
|
)
|
||||||
response_upload: PixverseImageUploadResponse = operation.execute()
|
response_upload: PixverseImageUploadResponse = await operation.execute()
|
||||||
|
|
||||||
if response_upload.Resp is None:
|
if response_upload.Resp is None:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@ -164,7 +164,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
@ -205,7 +205,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
@ -229,11 +229,11 @@ class PixverseTextToVideoNode(ComfyNodeABC):
|
|||||||
result_url_extractor=get_video_url_from_response,
|
result_url_extractor=get_video_url_from_response,
|
||||||
estimated_duration=AVERAGE_DURATION_T2V,
|
estimated_duration=AVERAGE_DURATION_T2V,
|
||||||
)
|
)
|
||||||
response_poll = operation.execute()
|
response_poll = await operation.execute()
|
||||||
|
|
||||||
vid_response = requests.get(response_poll.Resp.url)
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(response_poll.Resp.url) as vid_response:
|
||||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||||
|
|
||||||
|
|
||||||
class PixverseImageToVideoNode(ComfyNodeABC):
|
class PixverseImageToVideoNode(ComfyNodeABC):
|
||||||
@ -302,7 +302,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -316,7 +316,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs)
|
||||||
|
|
||||||
# 1080p is limited to 5 seconds duration
|
# 1080p is limited to 5 seconds duration
|
||||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
@ -345,7 +345,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
@ -369,10 +369,11 @@ class PixverseImageToVideoNode(ComfyNodeABC):
|
|||||||
result_url_extractor=get_video_url_from_response,
|
result_url_extractor=get_video_url_from_response,
|
||||||
estimated_duration=AVERAGE_DURATION_I2V,
|
estimated_duration=AVERAGE_DURATION_I2V,
|
||||||
)
|
)
|
||||||
response_poll = operation.execute()
|
response_poll = await operation.execute()
|
||||||
|
|
||||||
vid_response = requests.get(response_poll.Resp.url)
|
async with aiohttp.ClientSession() as session:
|
||||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
async with session.get(response_poll.Resp.url) as vid_response:
|
||||||
|
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||||
|
|
||||||
|
|
||||||
class PixverseTransitionVideoNode(ComfyNodeABC):
|
class PixverseTransitionVideoNode(ComfyNodeABC):
|
||||||
@ -436,7 +437,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
first_frame: torch.Tensor,
|
first_frame: torch.Tensor,
|
||||||
last_frame: torch.Tensor,
|
last_frame: torch.Tensor,
|
||||||
@ -450,8 +451,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
|
||||||
last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
|
||||||
|
|
||||||
# 1080p is limited to 5 seconds duration
|
# 1080p is limited to 5 seconds duration
|
||||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
@ -480,7 +481,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
@ -504,10 +505,11 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
|
|||||||
result_url_extractor=get_video_url_from_response,
|
result_url_extractor=get_video_url_from_response,
|
||||||
estimated_duration=AVERAGE_DURATION_T2V,
|
estimated_duration=AVERAGE_DURATION_T2V,
|
||||||
)
|
)
|
||||||
response_poll = operation.execute()
|
response_poll = await operation.execute()
|
||||||
|
|
||||||
vid_response = requests.get(response_poll.Resp.url)
|
async with aiohttp.ClientSession() as session:
|
||||||
return (VideoFromFile(BytesIO(vid_response.content)),)
|
async with session.get(response_poll.Resp.url) as vid_response:
|
||||||
|
return (VideoFromFile(BytesIO(await vid_response.content.read())),)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
@ -37,7 +37,7 @@ from io import BytesIO
|
|||||||
from PIL import UnidentifiedImageError
|
from PIL import UnidentifiedImageError
|
||||||
|
|
||||||
|
|
||||||
def handle_recraft_file_request(
|
async def handle_recraft_file_request(
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
path: str,
|
path: str,
|
||||||
mask: torch.Tensor=None,
|
mask: torch.Tensor=None,
|
||||||
@ -71,13 +71,13 @@ def handle_recraft_file_request(
|
|||||||
auth_kwargs=auth_kwargs,
|
auth_kwargs=auth_kwargs,
|
||||||
multipart_parser=recraft_multipart_parser,
|
multipart_parser=recraft_multipart_parser,
|
||||||
)
|
)
|
||||||
response: RecraftImageGenerationResponse = operation.execute()
|
response: RecraftImageGenerationResponse = await operation.execute()
|
||||||
all_bytesio = []
|
all_bytesio = []
|
||||||
if response.image is not None:
|
if response.image is not None:
|
||||||
all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout))
|
all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout))
|
||||||
else:
|
else:
|
||||||
for data in response.data:
|
for data in response.data:
|
||||||
all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout))
|
all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout))
|
||||||
|
|
||||||
return all_bytesio
|
return all_bytesio
|
||||||
|
|
||||||
@ -395,7 +395,7 @@ class RecraftTextToImageNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
size: str,
|
size: str,
|
||||||
@ -439,7 +439,7 @@ class RecraftTextToImageNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response: RecraftImageGenerationResponse = operation.execute()
|
response: RecraftImageGenerationResponse = await operation.execute()
|
||||||
images = []
|
images = []
|
||||||
urls = []
|
urls = []
|
||||||
for data in response.data:
|
for data in response.data:
|
||||||
@ -451,7 +451,7 @@ class RecraftTextToImageNode:
|
|||||||
f"Result URL: {urls_string}", unique_id
|
f"Result URL: {urls_string}", unique_id
|
||||||
)
|
)
|
||||||
image = bytesio_to_image_tensor(
|
image = bytesio_to_image_tensor(
|
||||||
download_url_to_bytesio(data.url, timeout=1024)
|
await download_url_to_bytesio(data.url, timeout=1024)
|
||||||
)
|
)
|
||||||
if len(image.shape) < 4:
|
if len(image.shape) < 4:
|
||||||
image = image.unsqueeze(0)
|
image = image.unsqueeze(0)
|
||||||
@ -538,7 +538,7 @@ class RecraftImageToImageNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -578,7 +578,7 @@ class RecraftImageToImageNode:
|
|||||||
total = image.shape[0]
|
total = image.shape[0]
|
||||||
pbar = ProgressBar(total)
|
pbar = ProgressBar(total)
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
sub_bytes = handle_recraft_file_request(
|
sub_bytes = await handle_recraft_file_request(
|
||||||
image=image[i],
|
image=image[i],
|
||||||
path="/proxy/recraft/images/imageToImage",
|
path="/proxy/recraft/images/imageToImage",
|
||||||
request=request,
|
request=request,
|
||||||
@ -654,7 +654,7 @@ class RecraftImageInpaintingNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
mask: torch.Tensor,
|
mask: torch.Tensor,
|
||||||
@ -690,7 +690,7 @@ class RecraftImageInpaintingNode:
|
|||||||
total = image.shape[0]
|
total = image.shape[0]
|
||||||
pbar = ProgressBar(total)
|
pbar = ProgressBar(total)
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
sub_bytes = handle_recraft_file_request(
|
sub_bytes = await handle_recraft_file_request(
|
||||||
image=image[i],
|
image=image[i],
|
||||||
mask=mask[i:i+1],
|
mask=mask[i:i+1],
|
||||||
path="/proxy/recraft/images/inpaint",
|
path="/proxy/recraft/images/inpaint",
|
||||||
@ -779,7 +779,7 @@ class RecraftTextToVectorNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
substyle: str,
|
substyle: str,
|
||||||
@ -821,7 +821,7 @@ class RecraftTextToVectorNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response: RecraftImageGenerationResponse = operation.execute()
|
response: RecraftImageGenerationResponse = await operation.execute()
|
||||||
svg_data = []
|
svg_data = []
|
||||||
urls = []
|
urls = []
|
||||||
for data in response.data:
|
for data in response.data:
|
||||||
@ -831,7 +831,7 @@ class RecraftTextToVectorNode:
|
|||||||
PromptServer.instance.send_progress_text(
|
PromptServer.instance.send_progress_text(
|
||||||
f"Result URL: {' '.join(urls)}", unique_id
|
f"Result URL: {' '.join(urls)}", unique_id
|
||||||
)
|
)
|
||||||
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
|
svg_data.append(await download_url_to_bytesio(data.url, timeout=1024))
|
||||||
|
|
||||||
return (SVG(svg_data),)
|
return (SVG(svg_data),)
|
||||||
|
|
||||||
@ -861,7 +861,7 @@ class RecraftVectorizeImageNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -870,7 +870,7 @@ class RecraftVectorizeImageNode:
|
|||||||
total = image.shape[0]
|
total = image.shape[0]
|
||||||
pbar = ProgressBar(total)
|
pbar = ProgressBar(total)
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
sub_bytes = handle_recraft_file_request(
|
sub_bytes = await handle_recraft_file_request(
|
||||||
image=image[i],
|
image=image[i],
|
||||||
path="/proxy/recraft/images/vectorize",
|
path="/proxy/recraft/images/vectorize",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
@ -942,7 +942,7 @@ class RecraftReplaceBackgroundNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -973,7 +973,7 @@ class RecraftReplaceBackgroundNode:
|
|||||||
total = image.shape[0]
|
total = image.shape[0]
|
||||||
pbar = ProgressBar(total)
|
pbar = ProgressBar(total)
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
sub_bytes = handle_recraft_file_request(
|
sub_bytes = await handle_recraft_file_request(
|
||||||
image=image[i],
|
image=image[i],
|
||||||
path="/proxy/recraft/images/replaceBackground",
|
path="/proxy/recraft/images/replaceBackground",
|
||||||
request=request,
|
request=request,
|
||||||
@ -1011,7 +1011,7 @@ class RecraftRemoveBackgroundNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -1020,7 +1020,7 @@ class RecraftRemoveBackgroundNode:
|
|||||||
total = image.shape[0]
|
total = image.shape[0]
|
||||||
pbar = ProgressBar(total)
|
pbar = ProgressBar(total)
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
sub_bytes = handle_recraft_file_request(
|
sub_bytes = await handle_recraft_file_request(
|
||||||
image=image[i],
|
image=image[i],
|
||||||
path="/proxy/recraft/images/removeBackground",
|
path="/proxy/recraft/images/removeBackground",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
@ -1062,7 +1062,7 @@ class RecraftCrispUpscaleNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -1071,7 +1071,7 @@ class RecraftCrispUpscaleNode:
|
|||||||
total = image.shape[0]
|
total = image.shape[0]
|
||||||
pbar = ProgressBar(total)
|
pbar = ProgressBar(total)
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
sub_bytes = handle_recraft_file_request(
|
sub_bytes = await handle_recraft_file_request(
|
||||||
image=image[i],
|
image=image[i],
|
||||||
path=self.RECRAFT_PATH,
|
path=self.RECRAFT_PATH,
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
|
|||||||
@ -9,11 +9,10 @@ from __future__ import annotations
|
|||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy.comfy_types.node_typing import IO
|
||||||
from comfy.cmd import folder_paths as comfy_paths
|
from comfy.cmd import folder_paths as comfy_paths
|
||||||
import requests
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
import shutil
|
import asyncio
|
||||||
import time
|
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse):
|
|||||||
return hasattr(response, "error")
|
return hasattr(response, "error")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Rodin3DAPI:
|
class Rodin3DAPI:
|
||||||
"""
|
"""
|
||||||
Generate 3D Assets using Rodin API
|
Generate 3D Assets using Rodin API
|
||||||
@ -123,8 +121,8 @@ class Rodin3DAPI:
|
|||||||
else:
|
else:
|
||||||
return "Generating"
|
return "Generating"
|
||||||
|
|
||||||
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
|
||||||
if images == None:
|
if images is None:
|
||||||
raise Exception("Rodin 3D generate requires at least 1 image.")
|
raise Exception("Rodin 3D generate requires at least 1 image.")
|
||||||
if len(images) >= 5:
|
if len(images) >= 5:
|
||||||
raise Exception("Rodin 3D generate requires up to 5 image.")
|
raise Exception("Rodin 3D generate requires up to 5 image.")
|
||||||
@ -155,7 +153,7 @@ class Rodin3DAPI:
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = operation.execute()
|
response = await operation.execute()
|
||||||
|
|
||||||
if create_task_error(response):
|
if create_task_error(response):
|
||||||
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
||||||
@ -168,7 +166,7 @@ class Rodin3DAPI:
|
|||||||
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
||||||
return task_uuid, subscription_key
|
return task_uuid, subscription_key
|
||||||
|
|
||||||
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
|
||||||
|
|
||||||
path = "/proxy/rodin/api/v2/status"
|
path = "/proxy/rodin/api/v2/status"
|
||||||
|
|
||||||
@ -191,11 +189,9 @@ class Rodin3DAPI:
|
|||||||
|
|
||||||
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
||||||
|
|
||||||
return poll_operation.execute()
|
return await poll_operation.execute()
|
||||||
|
|
||||||
|
async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
||||||
|
|
||||||
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
|
|
||||||
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
||||||
|
|
||||||
path = "/proxy/rodin/api/v2/download"
|
path = "/proxy/rodin/api/v2/download"
|
||||||
@ -212,53 +208,59 @@ class Rodin3DAPI:
|
|||||||
auth_kwargs=kwargs
|
auth_kwargs=kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
return operation.execute()
|
return await operation.execute()
|
||||||
|
|
||||||
def GetQualityAndMode(self, PolyCount):
|
def get_quality_mode(self, poly_count):
|
||||||
if PolyCount == "200K-Triangle":
|
if poly_count == "200K-Triangle":
|
||||||
mesh_mode = "Raw"
|
mesh_mode = "Raw"
|
||||||
quality = "medium"
|
quality = "medium"
|
||||||
else:
|
else:
|
||||||
mesh_mode = "Quad"
|
mesh_mode = "Quad"
|
||||||
if PolyCount == "4K-Quad":
|
if poly_count == "4K-Quad":
|
||||||
quality = "extra-low"
|
quality = "extra-low"
|
||||||
elif PolyCount == "8K-Quad":
|
elif poly_count == "8K-Quad":
|
||||||
quality = "low"
|
quality = "low"
|
||||||
elif PolyCount == "18K-Quad":
|
elif poly_count == "18K-Quad":
|
||||||
quality = "medium"
|
quality = "medium"
|
||||||
elif PolyCount == "50K-Quad":
|
elif poly_count == "50K-Quad":
|
||||||
quality = "high"
|
quality = "high"
|
||||||
else:
|
else:
|
||||||
quality = "medium"
|
quality = "medium"
|
||||||
|
|
||||||
return mesh_mode, quality
|
return mesh_mode, quality
|
||||||
|
|
||||||
def DownLoadFiles(self, Url_List):
|
async def download_files(self, url_list):
|
||||||
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||||
os.makedirs(Save_path, exist_ok=True)
|
os.makedirs(save_path, exist_ok=True)
|
||||||
model_file_path = None
|
model_file_path = None
|
||||||
for Item in Url_List.list:
|
async with aiohttp.ClientSession() as session:
|
||||||
url = Item.url
|
for i in url_list.list:
|
||||||
file_name = Item.name
|
url = i.url
|
||||||
file_path = os.path.join(Save_path, file_name)
|
file_name = i.name
|
||||||
if file_path.endswith(".glb"):
|
file_path = os.path.join(save_path, file_name)
|
||||||
model_file_path = file_path
|
if file_path.endswith(".glb"):
|
||||||
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
model_file_path = file_path
|
||||||
max_retries = 5
|
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
||||||
for attempt in range(max_retries):
|
max_retries = 5
|
||||||
try:
|
for attempt in range(max_retries):
|
||||||
with requests.get(url, stream=True) as r:
|
try:
|
||||||
r.raise_for_status()
|
async with session.get(url) as resp:
|
||||||
with open(file_path, "wb") as f:
|
resp.raise_for_status()
|
||||||
shutil.copyfileobj(r.raw, f)
|
with open(file_path, "wb") as f:
|
||||||
break
|
async for chunk in resp.content.iter_chunked(32 * 1024):
|
||||||
except Exception as e:
|
f.write(chunk)
|
||||||
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
break
|
||||||
if attempt < max_retries - 1:
|
except Exception as e:
|
||||||
logging.info("Retrying...")
|
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
||||||
time.sleep(2)
|
if attempt < max_retries - 1:
|
||||||
else:
|
logging.info("Retrying...")
|
||||||
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
|
await asyncio.sleep(2)
|
||||||
|
else:
|
||||||
|
logging.info(
|
||||||
|
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
|
||||||
|
file_path,
|
||||||
|
max_retries,
|
||||||
|
)
|
||||||
|
|
||||||
return model_file_path
|
return model_file_path
|
||||||
|
|
||||||
@ -285,7 +287,7 @@ class Rodin3D_Regular(Rodin3DAPI):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
Images,
|
Images,
|
||||||
Seed,
|
Seed,
|
||||||
@ -298,14 +300,17 @@ class Rodin3D_Regular(Rodin3DAPI):
|
|||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||||
self.poll_for_task_status(subscription_key, **kwargs)
|
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
**kwargs)
|
||||||
model = self.DownLoadFiles(Download_List)
|
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||||
|
model = await self.download_files(download_list)
|
||||||
|
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Detail(Rodin3DAPI):
|
class Rodin3D_Detail(Rodin3DAPI):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -328,7 +333,7 @@ class Rodin3D_Detail(Rodin3DAPI):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
Images,
|
Images,
|
||||||
Seed,
|
Seed,
|
||||||
@ -341,14 +346,17 @@ class Rodin3D_Detail(Rodin3DAPI):
|
|||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||||
self.poll_for_task_status(subscription_key, **kwargs)
|
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
**kwargs)
|
||||||
model = self.DownLoadFiles(Download_List)
|
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||||
|
model = await self.download_files(download_list)
|
||||||
|
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Smooth(Rodin3DAPI):
|
class Rodin3D_Smooth(Rodin3DAPI):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -371,7 +379,7 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
Images,
|
Images,
|
||||||
Seed,
|
Seed,
|
||||||
@ -384,14 +392,17 @@ class Rodin3D_Smooth(Rodin3DAPI):
|
|||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
|
mesh_mode, quality = self.get_quality_mode(Polygon_count)
|
||||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
|
||||||
self.poll_for_task_status(subscription_key, **kwargs)
|
quality=quality, tier=tier, mesh_mode=mesh_mode,
|
||||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
**kwargs)
|
||||||
model = self.DownLoadFiles(Download_List)
|
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||||
|
model = await self.download_files(download_list)
|
||||||
|
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Sketch(Rodin3DAPI):
|
class Rodin3D_Sketch(Rodin3DAPI):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
@ -423,7 +434,7 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
Images,
|
Images,
|
||||||
Seed,
|
Seed,
|
||||||
@ -437,10 +448,12 @@ class Rodin3D_Sketch(Rodin3DAPI):
|
|||||||
material_type = "PBR"
|
material_type = "PBR"
|
||||||
quality = "medium"
|
quality = "medium"
|
||||||
mesh_mode = "Quad"
|
mesh_mode = "Quad"
|
||||||
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
|
task_uuid, subscription_key = await self.create_generate_task(
|
||||||
self.poll_for_task_status(subscription_key, **kwargs)
|
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs
|
||||||
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
|
)
|
||||||
model = self.DownLoadFiles(Download_List)
|
await self.poll_for_task_status(subscription_key, **kwargs)
|
||||||
|
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
|
||||||
|
model = await self.download_files(download_list)
|
||||||
|
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
|
|||||||
@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool:
|
|||||||
return image.shape[2] < 8000 and image.shape[1] < 8000
|
return image.shape[2] < 8000 and image.shape[1] < 8000
|
||||||
|
|
||||||
|
|
||||||
def poll_until_finished(
|
async def poll_until_finished(
|
||||||
auth_kwargs: dict[str, str],
|
auth_kwargs: dict[str, str],
|
||||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||||
estimated_duration: Optional[int] = None,
|
estimated_duration: Optional[int] = None,
|
||||||
node_id: Optional[str] = None,
|
node_id: Optional[str] = None,
|
||||||
) -> TaskStatusResponse:
|
) -> TaskStatusResponse:
|
||||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||||
return PollingOperation(
|
return await PollingOperation(
|
||||||
poll_endpoint=api_endpoint,
|
poll_endpoint=api_endpoint,
|
||||||
completed_statuses=[
|
completed_statuses=[
|
||||||
TaskStatus.SUCCEEDED.value,
|
TaskStatus.SUCCEEDED.value,
|
||||||
@ -115,7 +115,7 @@ def poll_until_finished(
|
|||||||
TaskStatus.FAILED.value,
|
TaskStatus.FAILED.value,
|
||||||
TaskStatus.CANCELLED.value,
|
TaskStatus.CANCELLED.value,
|
||||||
],
|
],
|
||||||
status_extractor=lambda response: (response.status.value),
|
status_extractor=lambda response: response.status.value,
|
||||||
auth_kwargs=auth_kwargs,
|
auth_kwargs=auth_kwargs,
|
||||||
result_url_extractor=get_video_url_from_task_status,
|
result_url_extractor=get_video_url_from_task_status,
|
||||||
estimated_duration=estimated_duration,
|
estimated_duration=estimated_duration,
|
||||||
@ -167,11 +167,11 @@ class RunwayVideoGenNode(ComfyNodeABC):
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> RunwayImageToVideoResponse:
|
) -> RunwayImageToVideoResponse:
|
||||||
"""Poll the task status until it is finished then get the response."""
|
"""Poll the task status until it is finished then get the response."""
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||||
@ -183,7 +183,7 @@ class RunwayVideoGenNode(ComfyNodeABC):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_video(
|
async def generate_video(
|
||||||
self,
|
self,
|
||||||
request: RunwayImageToVideoRequest,
|
request: RunwayImageToVideoRequest,
|
||||||
auth_kwargs: dict[str, str],
|
auth_kwargs: dict[str, str],
|
||||||
@ -200,15 +200,15 @@ class RunwayVideoGenNode(ComfyNodeABC):
|
|||||||
auth_kwargs=auth_kwargs,
|
auth_kwargs=auth_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_response = initial_operation.execute()
|
initial_response = await initial_operation.execute()
|
||||||
self.validate_task_created(initial_response)
|
self.validate_task_created(initial_response)
|
||||||
task_id = initial_response.id
|
task_id = initial_response.id
|
||||||
|
|
||||||
final_response = self.get_response(task_id, auth_kwargs, node_id)
|
final_response = await self.get_response(task_id, auth_kwargs, node_id)
|
||||||
self.validate_response(final_response)
|
self.validate_response(final_response)
|
||||||
|
|
||||||
video_url = get_video_url_from_task_status(final_response)
|
video_url = get_video_url_from_task_status(final_response)
|
||||||
return (download_url_to_video_output(video_url),)
|
return (await download_url_to_video_output(video_url),)
|
||||||
|
|
||||||
|
|
||||||
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
||||||
@ -250,7 +250,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
start_frame: torch.Tensor,
|
start_frame: torch.Tensor,
|
||||||
@ -265,7 +265,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
|||||||
validate_input_image(start_frame)
|
validate_input_image(start_frame)
|
||||||
|
|
||||||
# Upload image
|
# Upload image
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
start_frame,
|
start_frame,
|
||||||
max_images=1,
|
max_images=1,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
@ -274,7 +274,7 @@ class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
|
|||||||
if len(download_urls) != 1:
|
if len(download_urls) != 1:
|
||||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
return self.generate_video(
|
return await self.generate_video(
|
||||||
RunwayImageToVideoRequest(
|
RunwayImageToVideoRequest(
|
||||||
promptText=prompt,
|
promptText=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -333,7 +333,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
start_frame: torch.Tensor,
|
start_frame: torch.Tensor,
|
||||||
@ -348,7 +348,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
|||||||
validate_input_image(start_frame)
|
validate_input_image(start_frame)
|
||||||
|
|
||||||
# Upload image
|
# Upload image
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
start_frame,
|
start_frame,
|
||||||
max_images=1,
|
max_images=1,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
@ -357,7 +357,7 @@ class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
|
|||||||
if len(download_urls) != 1:
|
if len(download_urls) != 1:
|
||||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
return self.generate_video(
|
return await self.generate_video(
|
||||||
RunwayImageToVideoRequest(
|
RunwayImageToVideoRequest(
|
||||||
promptText=prompt,
|
promptText=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
|||||||
|
|
||||||
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
|
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> RunwayImageToVideoResponse:
|
) -> RunwayImageToVideoResponse:
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||||
@ -437,7 +437,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
start_frame: torch.Tensor,
|
start_frame: torch.Tensor,
|
||||||
@ -455,7 +455,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
|||||||
|
|
||||||
# Upload images
|
# Upload images
|
||||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
stacked_input_images,
|
stacked_input_images,
|
||||||
max_images=2,
|
max_images=2,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
@ -464,7 +464,7 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode):
|
|||||||
if len(download_urls) != 2:
|
if len(download_urls) != 2:
|
||||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
return self.generate_video(
|
return await self.generate_video(
|
||||||
RunwayImageToVideoRequest(
|
RunwayImageToVideoRequest(
|
||||||
promptText=prompt,
|
promptText=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -543,11 +543,11 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
|||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def get_response(
|
async def get_response(
|
||||||
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||||
) -> TaskStatusResponse:
|
) -> TaskStatusResponse:
|
||||||
"""Poll the task status until it is finished then get the response."""
|
"""Poll the task status until it is finished then get the response."""
|
||||||
return poll_until_finished(
|
return await poll_until_finished(
|
||||||
auth_kwargs,
|
auth_kwargs,
|
||||||
ApiEndpoint(
|
ApiEndpoint(
|
||||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||||
@ -559,7 +559,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
|||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def api_call(
|
async def api_call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
@ -574,7 +574,7 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
|||||||
reference_images = None
|
reference_images = None
|
||||||
if reference_image is not None:
|
if reference_image is not None:
|
||||||
validate_input_image(reference_image)
|
validate_input_image(reference_image)
|
||||||
download_urls = upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
reference_image,
|
reference_image,
|
||||||
max_images=1,
|
max_images=1,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
@ -605,19 +605,19 @@ class RunwayTextToImageNode(ComfyNodeABC):
|
|||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_response = initial_operation.execute()
|
initial_response = await initial_operation.execute()
|
||||||
self.validate_task_created(initial_response)
|
self.validate_task_created(initial_response)
|
||||||
task_id = initial_response.id
|
task_id = initial_response.id
|
||||||
|
|
||||||
# Poll for completion
|
# Poll for completion
|
||||||
final_response = self.get_response(
|
final_response = await self.get_response(
|
||||||
task_id, auth_kwargs=kwargs, node_id=unique_id
|
task_id, auth_kwargs=kwargs, node_id=unique_id
|
||||||
)
|
)
|
||||||
self.validate_response(final_response)
|
self.validate_response(final_response)
|
||||||
|
|
||||||
# Download and return image
|
# Download and return image
|
||||||
image_url = get_image_url_from_task_status(final_response)
|
image_url = get_image_url_from_task_status(final_response)
|
||||||
return (download_url_to_image_tensor(image_url),)
|
return (await download_url_to_image_tensor(image_url),)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
@ -124,7 +124,7 @@ class StabilityStableImageUltraNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
|
||||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
@ -163,7 +163,7 @@ class StabilityStableImageUltraNode:
|
|||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||||
@ -257,7 +257,7 @@ class StabilityStableImageSD_3_5Node:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
|
||||||
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
@ -302,7 +302,7 @@ class StabilityStableImageSD_3_5Node:
|
|||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||||
@ -374,7 +374,7 @@ class StabilityUpscaleConservativeNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||||
@ -403,7 +403,7 @@ class StabilityUpscaleConservativeNode:
|
|||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||||
@ -480,7 +480,7 @@ class StabilityUpscaleCreativeNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||||
@ -512,7 +512,7 @@ class StabilityUpscaleCreativeNode:
|
|||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
operation = PollingOperation(
|
operation = PollingOperation(
|
||||||
poll_endpoint=ApiEndpoint(
|
poll_endpoint=ApiEndpoint(
|
||||||
@ -527,7 +527,7 @@ class StabilityUpscaleCreativeNode:
|
|||||||
status_extractor=lambda x: get_async_dummy_status(x),
|
status_extractor=lambda x: get_async_dummy_status(x),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_poll: StabilityResultsGetResponse = operation.execute()
|
response_poll: StabilityResultsGetResponse = await operation.execute()
|
||||||
|
|
||||||
if response_poll.finish_reason != "SUCCESS":
|
if response_poll.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||||
@ -563,8 +563,7 @@ class StabilityUpscaleFastNode:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def api_call(self, image: torch.Tensor,
|
async def api_call(self, image: torch.Tensor, **kwargs):
|
||||||
**kwargs):
|
|
||||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||||
|
|
||||||
files = {
|
files = {
|
||||||
@ -583,7 +582,7 @@ class StabilityUpscaleFastNode:
|
|||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
)
|
)
|
||||||
response_api = operation.execute()
|
response_api = await operation.execute()
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||||
|
|||||||
@ -37,8 +37,8 @@ from comfy_api_nodes.apinode_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def upload_image_to_tripo(image, **kwargs):
|
async def upload_image_to_tripo(image, **kwargs):
|
||||||
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
|
||||||
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
|
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
|
||||||
|
|
||||||
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
||||||
@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
|
|||||||
raise RuntimeError(f"Failed to get model url from response: {response}")
|
raise RuntimeError(f"Failed to get model url from response: {response}")
|
||||||
|
|
||||||
|
|
||||||
def poll_until_finished(
|
async def poll_until_finished(
|
||||||
kwargs: dict[str, str],
|
kwargs: dict[str, str],
|
||||||
response: TripoTaskResponse,
|
response: TripoTaskResponse,
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
@ -57,7 +57,7 @@ def poll_until_finished(
|
|||||||
if response.code != 0:
|
if response.code != 0:
|
||||||
raise RuntimeError(f"Failed to generate mesh: {response.error}")
|
raise RuntimeError(f"Failed to generate mesh: {response.error}")
|
||||||
task_id = response.data.task_id
|
task_id = response.data.task_id
|
||||||
response_poll = PollingOperation(
|
response_poll = await PollingOperation(
|
||||||
poll_endpoint=ApiEndpoint(
|
poll_endpoint=ApiEndpoint(
|
||||||
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
|
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
|
||||||
method=HttpMethod.GET,
|
method=HttpMethod.GET,
|
||||||
@ -80,7 +80,7 @@ def poll_until_finished(
|
|||||||
).execute()
|
).execute()
|
||||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||||
url = get_model_url_from_response(response_poll)
|
url = get_model_url_from_response(response_poll)
|
||||||
bytesio = download_url_to_bytesio(url)
|
bytesio = await download_url_to_bytesio(url)
|
||||||
# Save the downloaded model file
|
# Save the downloaded model file
|
||||||
model_file = f"tripo_model_{task_id}.glb"
|
model_file = f"tripo_model_{task_id}.glb"
|
||||||
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
|
||||||
@ -88,6 +88,7 @@ def poll_until_finished(
|
|||||||
return model_file, task_id
|
return model_file, task_id
|
||||||
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
|
||||||
|
|
||||||
|
|
||||||
class TripoTextToModelNode:
|
class TripoTextToModelNode:
|
||||||
"""
|
"""
|
||||||
Generates 3D models synchronously based on a text prompt using Tripo's API.
|
Generates 3D models synchronously based on a text prompt using Tripo's API.
|
||||||
@ -126,11 +127,11 @@ class TripoTextToModelNode:
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||||
style_enum = None if style == "None" else style
|
style_enum = None if style == "None" else style
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise RuntimeError("Prompt is required")
|
raise RuntimeError("Prompt is required")
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -155,7 +156,8 @@ class TripoTextToModelNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
class TripoImageToModelNode:
|
class TripoImageToModelNode:
|
||||||
"""
|
"""
|
||||||
@ -195,12 +197,12 @@ class TripoImageToModelNode:
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
|
||||||
style_enum = None if style == "None" else style
|
style_enum = None if style == "None" else style
|
||||||
if image is None:
|
if image is None:
|
||||||
raise RuntimeError("Image is required")
|
raise RuntimeError("Image is required")
|
||||||
tripo_file = upload_image_to_tripo(image, **kwargs)
|
tripo_file = await upload_image_to_tripo(image, **kwargs)
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -225,7 +227,8 @@ class TripoImageToModelNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
class TripoMultiviewToModelNode:
|
class TripoMultiviewToModelNode:
|
||||||
"""
|
"""
|
||||||
@ -267,7 +270,7 @@ class TripoMultiviewToModelNode:
|
|||||||
API_NODE = True
|
API_NODE = True
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
|
||||||
if image is None:
|
if image is None:
|
||||||
raise RuntimeError("front image for multiview is required")
|
raise RuntimeError("front image for multiview is required")
|
||||||
images = []
|
images = []
|
||||||
@ -282,11 +285,11 @@ class TripoMultiviewToModelNode:
|
|||||||
for image_name in ["image", "image_left", "image_back", "image_right"]:
|
for image_name in ["image", "image_left", "image_back", "image_right"]:
|
||||||
image_ = image_dict[image_name]
|
image_ = image_dict[image_name]
|
||||||
if image_ is not None:
|
if image_ is not None:
|
||||||
tripo_file = upload_image_to_tripo(image_, **kwargs)
|
tripo_file = await upload_image_to_tripo(image_, **kwargs)
|
||||||
images.append(tripo_file)
|
images.append(tripo_file)
|
||||||
else:
|
else:
|
||||||
images.append(TripoFileEmptyReference())
|
images.append(TripoFileEmptyReference())
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -309,7 +312,8 @@ class TripoMultiviewToModelNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
class TripoTextureNode:
|
class TripoTextureNode:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -340,8 +344,8 @@ class TripoTextureNode:
|
|||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
AVERAGE_DURATION = 80
|
AVERAGE_DURATION = 80
|
||||||
|
|
||||||
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -358,7 +362,7 @@ class TripoTextureNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
class TripoRefineNode:
|
class TripoRefineNode:
|
||||||
@ -387,8 +391,8 @@ class TripoRefineNode:
|
|||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
AVERAGE_DURATION = 240
|
AVERAGE_DURATION = 240
|
||||||
|
|
||||||
def generate_mesh(self, model_task_id, **kwargs):
|
async def generate_mesh(self, model_task_id, **kwargs):
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -400,7 +404,7 @@ class TripoRefineNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
class TripoRigNode:
|
class TripoRigNode:
|
||||||
@ -425,8 +429,8 @@ class TripoRigNode:
|
|||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
AVERAGE_DURATION = 180
|
AVERAGE_DURATION = 180
|
||||||
|
|
||||||
def generate_mesh(self, original_model_task_id, **kwargs):
|
async def generate_mesh(self, original_model_task_id, **kwargs):
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -440,7 +444,8 @@ class TripoRigNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
class TripoRetargetNode:
|
class TripoRetargetNode:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -475,8 +480,8 @@ class TripoRetargetNode:
|
|||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
AVERAGE_DURATION = 30
|
AVERAGE_DURATION = 30
|
||||||
|
|
||||||
def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
async def generate_mesh(self, animation, original_model_task_id, **kwargs):
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -491,7 +496,8 @@ class TripoRetargetNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
class TripoConversionNode:
|
class TripoConversionNode:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -529,10 +535,10 @@ class TripoConversionNode:
|
|||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
AVERAGE_DURATION = 30
|
AVERAGE_DURATION = 30
|
||||||
|
|
||||||
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
|
||||||
if not original_model_task_id:
|
if not original_model_task_id:
|
||||||
raise RuntimeError("original_model_task_id is required")
|
raise RuntimeError("original_model_task_id is required")
|
||||||
response = SynchronousOperation(
|
response = await SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
path="/proxy/tripo/v2/openapi/task",
|
path="/proxy/tripo/v2/openapi/task",
|
||||||
method=HttpMethod.POST,
|
method=HttpMethod.POST,
|
||||||
@ -549,7 +555,8 @@ class TripoConversionNode:
|
|||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=kwargs,
|
||||||
).execute()
|
).execute()
|
||||||
return poll_until_finished(kwargs, response)
|
return await poll_until_finished(kwargs, response)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"TripoTextToModelNode": TripoTextToModelNode,
|
"TripoTextToModelNode": TripoTextToModelNode,
|
||||||
|
|||||||
@ -1,17 +1,18 @@
|
|||||||
import io
|
|
||||||
import logging
|
import logging
|
||||||
import base64
|
import base64
|
||||||
import requests
|
import aiohttp
|
||||||
import torch
|
import torch
|
||||||
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
|
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.input_impl.video_types import VideoFromFile
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
VeoGenVidRequest,
|
VeoGenVidRequest,
|
||||||
VeoGenVidResponse,
|
VeoGenVidResponse,
|
||||||
VeoGenVidPollRequest,
|
VeoGenVidPollRequest,
|
||||||
VeoGenVidPollResponse
|
VeoGenVidPollResponse,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.apis.client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
@ -22,7 +23,7 @@ from comfy_api_nodes.apis.client import (
|
|||||||
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
from comfy_api_nodes.apinode_utils import (
|
||||||
downscale_image_tensor,
|
downscale_image_tensor,
|
||||||
tensor_to_base64_string
|
tensor_to_base64_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||||
@ -50,7 +51,7 @@ def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optiona
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class VeoVideoGenerationNode(ComfyNodeABC):
|
class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Generates videos from text prompts using Google's Veo API.
|
Generates videos from text prompts using Google's Veo API.
|
||||||
|
|
||||||
@ -59,101 +60,93 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return comfy_io.Schema(
|
||||||
"required": {
|
node_id="VeoVideoGenerationNode",
|
||||||
"prompt": (
|
display_name="Google Veo 2 Video Generation",
|
||||||
IO.STRING,
|
category="api node/video/Veo",
|
||||||
{
|
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||||
"multiline": True,
|
inputs=[
|
||||||
"default": "",
|
comfy_io.String.Input(
|
||||||
"tooltip": "Text description of the video",
|
"prompt",
|
||||||
},
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text description of the video",
|
||||||
),
|
),
|
||||||
"aspect_ratio": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"aspect_ratio",
|
||||||
{
|
options=["16:9", "9:16"],
|
||||||
"options": ["16:9", "9:16"],
|
default="16:9",
|
||||||
"default": "16:9",
|
tooltip="Aspect ratio of the output video",
|
||||||
"tooltip": "Aspect ratio of the output video",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
},
|
comfy_io.String.Input(
|
||||||
"optional": {
|
"negative_prompt",
|
||||||
"negative_prompt": (
|
multiline=True,
|
||||||
IO.STRING,
|
default="",
|
||||||
{
|
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||||
"multiline": True,
|
optional=True,
|
||||||
"default": "",
|
|
||||||
"tooltip": "Negative text prompt to guide what to avoid in the video",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"duration_seconds": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"duration_seconds",
|
||||||
{
|
default=5,
|
||||||
"default": 5,
|
min=5,
|
||||||
"min": 5,
|
max=8,
|
||||||
"max": 8,
|
step=1,
|
||||||
"step": 1,
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
"display": "number",
|
tooltip="Duration of the output video in seconds",
|
||||||
"tooltip": "Duration of the output video in seconds",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"enhance_prompt": (
|
comfy_io.Boolean.Input(
|
||||||
IO.BOOLEAN,
|
"enhance_prompt",
|
||||||
{
|
default=True,
|
||||||
"default": True,
|
tooltip="Whether to enhance the prompt with AI assistance",
|
||||||
"tooltip": "Whether to enhance the prompt with AI assistance",
|
optional=True,
|
||||||
}
|
|
||||||
),
|
),
|
||||||
"person_generation": (
|
comfy_io.Combo.Input(
|
||||||
IO.COMBO,
|
"person_generation",
|
||||||
{
|
options=["ALLOW", "BLOCK"],
|
||||||
"options": ["ALLOW", "BLOCK"],
|
default="ALLOW",
|
||||||
"default": "ALLOW",
|
tooltip="Whether to allow generating people in the video",
|
||||||
"tooltip": "Whether to allow generating people in the video",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"seed": (
|
comfy_io.Int.Input(
|
||||||
IO.INT,
|
"seed",
|
||||||
{
|
default=0,
|
||||||
"default": 0,
|
min=0,
|
||||||
"min": 0,
|
max=0xFFFFFFFF,
|
||||||
"max": 0xFFFFFFFF,
|
step=1,
|
||||||
"step": 1,
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
"display": "number",
|
control_after_generate=True,
|
||||||
"control_after_generate": True,
|
tooltip="Seed for video generation (0 for random)",
|
||||||
"tooltip": "Seed for video generation (0 for random)",
|
optional=True,
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"image": (IO.IMAGE, {
|
comfy_io.Image.Input(
|
||||||
"default": None,
|
"image",
|
||||||
"tooltip": "Optional reference image to guide video generation",
|
tooltip="Optional reference image to guide video generation",
|
||||||
}),
|
optional=True,
|
||||||
"model": (
|
|
||||||
IO.COMBO,
|
|
||||||
{
|
|
||||||
"options": ["veo-2.0-generate-001"],
|
|
||||||
"default": "veo-2.0-generate-001",
|
|
||||||
"tooltip": "Veo 2 model to use for video generation",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
},
|
comfy_io.Combo.Input(
|
||||||
"hidden": {
|
"model",
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
options=["veo-2.0-generate-001"],
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
default="veo-2.0-generate-001",
|
||||||
"unique_id": "UNIQUE_ID",
|
tooltip="Veo 2 model to use for video generation",
|
||||||
},
|
optional=True,
|
||||||
}
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
comfy_io.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.VIDEO,)
|
@classmethod
|
||||||
FUNCTION = "generate_video"
|
async def execute(
|
||||||
CATEGORY = "api node/video/Veo"
|
cls,
|
||||||
DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API"
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def generate_video(
|
|
||||||
self,
|
|
||||||
prompt,
|
prompt,
|
||||||
aspect_ratio="16:9",
|
aspect_ratio="16:9",
|
||||||
negative_prompt="",
|
negative_prompt="",
|
||||||
@ -164,8 +157,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
image=None,
|
image=None,
|
||||||
model="veo-2.0-generate-001",
|
model="veo-2.0-generate-001",
|
||||||
generate_audio=False,
|
generate_audio=False,
|
||||||
unique_id: Optional[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
# Prepare the instances for the request
|
# Prepare the instances for the request
|
||||||
instances = []
|
instances = []
|
||||||
@ -202,6 +193,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
if "veo-3.0" in model:
|
if "veo-3.0" in model:
|
||||||
parameters["generateAudio"] = generate_audio
|
parameters["generateAudio"] = generate_audio
|
||||||
|
|
||||||
|
auth = {
|
||||||
|
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||||
|
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||||
|
}
|
||||||
# Initial request to start video generation
|
# Initial request to start video generation
|
||||||
initial_operation = SynchronousOperation(
|
initial_operation = SynchronousOperation(
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(
|
||||||
@ -214,10 +209,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
instances=instances,
|
instances=instances,
|
||||||
parameters=parameters
|
parameters=parameters
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_response = initial_operation.execute()
|
initial_response = await initial_operation.execute()
|
||||||
operation_name = initial_response.name
|
operation_name = initial_response.name
|
||||||
|
|
||||||
logging.info(f"Veo generation started with operation name: {operation_name}")
|
logging.info(f"Veo generation started with operation name: {operation_name}")
|
||||||
@ -248,15 +243,15 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
request=VeoGenVidPollRequest(
|
request=VeoGenVidPollRequest(
|
||||||
operationName=operation_name
|
operationName=operation_name
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
auth_kwargs=auth,
|
||||||
poll_interval=5.0,
|
poll_interval=5.0,
|
||||||
result_url_extractor=get_video_url_from_response,
|
result_url_extractor=get_video_url_from_response,
|
||||||
node_id=unique_id,
|
node_id=cls.hidden.unique_id,
|
||||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the polling operation
|
# Execute the polling operation
|
||||||
poll_response = poll_operation.execute()
|
poll_response = await poll_operation.execute()
|
||||||
|
|
||||||
# Now check for errors in the final response
|
# Now check for errors in the final response
|
||||||
# Check for error in poll response
|
# Check for error in poll response
|
||||||
@ -281,7 +276,6 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
raise Exception(error_message)
|
raise Exception(error_message)
|
||||||
|
|
||||||
# Extract video data
|
# Extract video data
|
||||||
video_data = None
|
|
||||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||||
video = poll_response.response.videos[0]
|
video = poll_response.response.videos[0]
|
||||||
|
|
||||||
@ -291,9 +285,9 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||||
# Download from URL
|
# Download from URL
|
||||||
video_url = video.gcsUri
|
async with aiohttp.ClientSession() as session:
|
||||||
video_response = requests.get(video_url)
|
async with session.get(video.gcsUri) as video_response:
|
||||||
video_data = video_response.content
|
video_data = await video_response.content.read()
|
||||||
else:
|
else:
|
||||||
raise Exception("Video returned but no data or URL was provided")
|
raise Exception("Video returned but no data or URL was provided")
|
||||||
else:
|
else:
|
||||||
@ -305,10 +299,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
|
|||||||
logging.info("Video generation completed successfully")
|
logging.info("Video generation completed successfully")
|
||||||
|
|
||||||
# Convert video data to BytesIO object
|
# Convert video data to BytesIO object
|
||||||
video_io = io.BytesIO(video_data)
|
video_io = BytesIO(video_data)
|
||||||
|
|
||||||
# Return VideoFromFile object
|
# Return VideoFromFile object
|
||||||
return (VideoFromFile(video_io),)
|
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
||||||
|
|
||||||
|
|
||||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||||
@ -324,51 +318,104 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
parent_input = super().INPUT_TYPES()
|
return comfy_io.Schema(
|
||||||
|
node_id="Veo3VideoGenerationNode",
|
||||||
# Update model options for Veo 3
|
display_name="Google Veo 3 Video Generation",
|
||||||
parent_input["optional"]["model"] = (
|
category="api node/video/Veo",
|
||||||
IO.COMBO,
|
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||||
{
|
inputs=[
|
||||||
"options": ["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
comfy_io.String.Input(
|
||||||
"default": "veo-3.0-generate-001",
|
"prompt",
|
||||||
"tooltip": "Veo 3 model to use for video generation",
|
multiline=True,
|
||||||
},
|
default="",
|
||||||
|
tooltip="Text description of the video",
|
||||||
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "9:16"],
|
||||||
|
default="16:9",
|
||||||
|
tooltip="Aspect ratio of the output video",
|
||||||
|
),
|
||||||
|
comfy_io.String.Input(
|
||||||
|
"negative_prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"duration_seconds",
|
||||||
|
default=8,
|
||||||
|
min=8,
|
||||||
|
max=8,
|
||||||
|
step=1,
|
||||||
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
|
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Boolean.Input(
|
||||||
|
"enhance_prompt",
|
||||||
|
default=True,
|
||||||
|
tooltip="Whether to enhance the prompt with AI assistance",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"person_generation",
|
||||||
|
options=["ALLOW", "BLOCK"],
|
||||||
|
default="ALLOW",
|
||||||
|
tooltip="Whether to allow generating people in the video",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFF,
|
||||||
|
step=1,
|
||||||
|
display_mode=comfy_io.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed for video generation (0 for random)",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Image.Input(
|
||||||
|
"image",
|
||||||
|
tooltip="Optional reference image to guide video generation",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
||||||
|
default="veo-3.0-generate-001",
|
||||||
|
tooltip="Veo 3 model to use for video generation",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
comfy_io.Boolean.Input(
|
||||||
|
"generate_audio",
|
||||||
|
default=False,
|
||||||
|
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
comfy_io.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
comfy_io.Hidden.auth_token_comfy_org,
|
||||||
|
comfy_io.Hidden.api_key_comfy_org,
|
||||||
|
comfy_io.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add generateAudio parameter
|
|
||||||
parent_input["optional"]["generate_audio"] = (
|
|
||||||
IO.BOOLEAN,
|
|
||||||
{
|
|
||||||
"default": False,
|
|
||||||
"tooltip": "Generate audio for the video. Supported by all Veo 3 models.",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update duration constraints for Veo 3 (only 8 seconds supported)
|
class VeoExtension(ComfyExtension):
|
||||||
parent_input["optional"]["duration_seconds"] = (
|
@override
|
||||||
IO.INT,
|
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
||||||
{
|
return [
|
||||||
"default": 8,
|
VeoVideoGenerationNode,
|
||||||
"min": 8,
|
Veo3VideoGenerationNode,
|
||||||
"max": 8,
|
]
|
||||||
"step": 1,
|
|
||||||
"display": "number",
|
|
||||||
"tooltip": "Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return parent_input
|
async def comfy_entrypoint() -> VeoExtension:
|
||||||
|
return VeoExtension()
|
||||||
|
|
||||||
# Register the nodes
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"VeoVideoGenerationNode": VeoVideoGenerationNode,
|
|
||||||
"Veo3VideoGenerationNode": Veo3VideoGenerationNode,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"VeoVideoGenerationNode": "Google Veo 2 Video Generation",
|
|
||||||
"Veo3VideoGenerationNode": "Google Veo 3 Video Generation",
|
|
||||||
}
|
|
||||||
|
|||||||
622
comfy_api_nodes/nodes_vidu.py
Normal file
622
comfy_api_nodes/nodes_vidu.py
Normal file
@ -0,0 +1,622 @@
|
|||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Optional, Literal, TypeVar
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io as comfy_io
|
||||||
|
from comfy_api_nodes.util.validation_utils import (
|
||||||
|
validate_aspect_ratio_closeness,
|
||||||
|
validate_image_dimensions,
|
||||||
|
validate_image_aspect_ratio_range,
|
||||||
|
get_number_of_images,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apis.client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
HttpMethod,
|
||||||
|
SynchronousOperation,
|
||||||
|
PollingOperation,
|
||||||
|
EmptyRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
|
||||||
|
|
||||||
|
|
||||||
|
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||||
|
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
|
||||||
|
VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
|
||||||
|
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
|
||||||
|
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||||
|
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
class VideoModelName(str, Enum):
    """Vidu model identifiers; members are plain strings."""

    vidu_q1 = "viduq1"
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatio(str, Enum):
    """Aspect ratios offered for the output video, as "W:H" strings."""

    r_16_9 = "16:9"  # landscape
    r_9_16 = "9:16"  # portrait
    r_1_1 = "1:1"  # square
|
||||||
|
|
||||||
|
|
||||||
|
class Resolution(str, Enum):
    """Output resolutions offered by the nodes; members are plain strings."""

    r_1080p = "1080p"
|
||||||
|
|
||||||
|
|
||||||
|
class MovementAmplitude(str, Enum):
    """Choices for the `movement_amplitude` request field."""

    auto = "auto"
    small = "small"
    medium = "medium"
    large = "large"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskCreationRequest(BaseModel):
    """Request body used to create a Vidu generation task.

    Every field has a default, so callers only set what they need.
    """

    model: VideoModelName = VideoModelName.vidu_q1
    prompt: Optional[str] = Field(default=None, max_length=1500)
    duration: Optional[Literal[5]] = 5  # 5 is the only value the type admits
    seed: Optional[int] = Field(default=0, ge=0, le=2147483647)
    aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
    resolution: Optional[Resolution] = Resolution.r_1080p
    movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
    images: Optional[list[str]] = Field(
        default=None,
        description="Base64 encoded string or image URL",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
    """States a Vidu task can report; members are plain strings."""

    created = "created"
    queueing = "queueing"
    processing = "processing"
    success = "success"
    failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskCreationResponse(BaseModel):
    """Response returned when a Vidu task is created."""

    # Required fields: an annotation without a default is equivalent to Field(...).
    task_id: str
    state: TaskStatus
    created_at: str
    code: Optional[int] = Field(default=None, description="Error code")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskResult(BaseModel):
    """A single generated creation returned by the status endpoint."""

    id: str = Field(description="Creation id")
    url: str = Field(description="The URL of the generated results, valid for one hour")
    cover_url: str = Field(description="The cover URL of the generated results, valid for one hour")
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatusResponse(BaseModel):
    """Response model for the task status / creations endpoint."""

    state: TaskStatus
    err_code: Optional[str] = Field(default=None)
    creations: list[TaskResult] = Field(description="Generated results")
|
||||||
|
|
||||||
|
|
||||||
|
async def poll_until_finished(
    auth_kwargs: dict[str, str],
    api_endpoint: ApiEndpoint[Any, R],
    result_url_extractor: Optional[Callable[[R], str]] = None,
    estimated_duration: Optional[int] = None,
    node_id: Optional[str] = None,
) -> R:
    """Run a PollingOperation against `api_endpoint` and return its result.

    The operation is configured with a 16-second poll interval and at most
    256 attempts; "success" is the completed status and "failed" the failed
    status, both read from the response's `state` field.
    """

    def _state_of(response):
        # Responses carry a TaskStatus enum; the operation compares raw strings.
        return response.state.value

    operation = PollingOperation(
        poll_endpoint=api_endpoint,
        completed_statuses=[TaskStatus.success.value],
        failed_statuses=[TaskStatus.failed.value],
        status_extractor=_state_of,
        auth_kwargs=auth_kwargs,
        result_url_extractor=result_url_extractor,
        estimated_duration=estimated_duration,
        node_id=node_id,
        poll_interval=16.0,
        max_poll_attempts=256,
    )
    return await operation.execute()
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_url_from_response(response) -> Optional[str]:
    """Return the URL of the first creation, or None when there are none."""
    creations = response.creations
    return creations[0].url if creations else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_from_response(response) -> TaskResult:
    """Return the first creation from a finished task response.

    Args:
        response: Object exposing `creations`, `state` and `err_code`.

    Returns:
        The first entry of `response.creations`.

    Raises:
        RuntimeError: If `response.creations` is empty or missing.
    """
    if not response.creations:
        error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
        # This is a failure path, so log it as an error (as execute_task does
        # for its own failure) rather than as routine info.
        logging.error(error_msg)
        raise RuntimeError(error_msg)
    first = response.creations[0]
    logging.info("Vidu task %s succeeded. Video URL: %s", first.id, first.url)
    return first
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_task(
    vidu_endpoint: str,
    auth_kwargs: Optional[dict[str, str]],
    payload: TaskCreationRequest,
    estimated_duration: int,
    node_id: Optional[str],
) -> TaskStatusResponse:
    """Create a Vidu task and poll its status endpoint until it finishes.

    Args:
        vidu_endpoint: Proxy path the creation request is POSTed to.
        auth_kwargs: Auth values forwarded to both the creation and polling calls.
        payload: Task creation request body.
        estimated_duration: Forwarded to the polling operation.
        node_id: Forwarded to the polling operation; may be None.

    Returns:
        The final status response produced by polling.

    Raises:
        RuntimeError: If the creation response already reports a failed state.
    """
    response = await SynchronousOperation(
        endpoint=ApiEndpoint(
            path=vidu_endpoint,
            method=HttpMethod.POST,
            request_model=TaskCreationRequest,
            response_model=TaskCreationResponse,
        ),
        request=payload,
        auth_kwargs=auth_kwargs,
    ).execute()
    # A task can be rejected at creation time; there is nothing to poll then.
    if response.state == TaskStatus.failed:
        error_msg = f"Vidu request failed. Code: {response.code}"
        logging.error(error_msg)
        raise RuntimeError(error_msg)
    return await poll_until_finished(
        auth_kwargs,
        ApiEndpoint(
            path=VIDU_GET_GENERATION_STATUS % response.task_id,
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=TaskStatusResponse,
        ),
        result_url_extractor=get_video_url_from_response,
        estimated_duration=estimated_duration,
        node_id=node_id,
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ViduTextToVideoNode(comfy_io.ComfyNode):
    """API node that generates a Vidu video from a text prompt."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs, its video output and hidden auth/id fields."""
        return comfy_io.Schema(
            node_id="ViduTextToVideoNode",
            display_name="Vidu Text To Video Generation",
            category="api node/video/Vidu",
            description="Generate video from text prompt",
            inputs=[
                comfy_io.Combo.Input(
                    "model",
                    options=[model.value for model in VideoModelName],
                    default=VideoModelName.vidu_q1.value,
                    tooltip="Model name",
                ),
                comfy_io.String.Input(
                    "prompt",
                    multiline=True,
                    tooltip="A textual description for video generation",
                ),
                comfy_io.Int.Input(
                    "duration",
                    default=5,
                    min=5,
                    max=5,
                    step=1,
                    display_mode=comfy_io.NumberDisplay.number,
                    tooltip="Duration of the output video in seconds",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=comfy_io.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed for video generation (0 for random)",
                    optional=True,
                ),
                comfy_io.Combo.Input(
                    "aspect_ratio",
                    options=[ratio.value for ratio in AspectRatio],
                    default=AspectRatio.r_16_9.value,
                    tooltip="The aspect ratio of the output video",
                    optional=True,
                ),
                comfy_io.Combo.Input(
                    "resolution",
                    options=[res.value for res in Resolution],
                    default=Resolution.r_1080p.value,
                    tooltip="Supported values may vary by model & duration",
                    optional=True,
                ),
                comfy_io.Combo.Input(
                    "movement_amplitude",
                    options=[amplitude.value for amplitude in MovementAmplitude],
                    default=MovementAmplitude.auto.value,
                    tooltip="The movement amplitude of objects in the frame",
                    optional=True,
                ),
            ],
            outputs=[
                comfy_io.Video.Output(),
            ],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
                comfy_io.Hidden.unique_id,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        prompt: str,
        duration: int,
        seed: int,
        aspect_ratio: str,
        resolution: str,
        movement_amplitude: str,
    ) -> comfy_io.NodeOutput:
        """Submit a text-to-video task and return the downloaded video.

        Raises:
            ValueError: If `prompt` is empty.
        """
        if not prompt:
            raise ValueError("The prompt field is required and cannot be empty.")
        payload = TaskCreationRequest(
            # The request field is `model`. Passing `model_name` would be an
            # unknown keyword (ignored by pydantic by default), silently
            # discarding the user's selection in favour of the default.
            model=model,
            prompt=prompt,
            duration=duration,
            seed=seed,
            aspect_ratio=aspect_ratio,
            resolution=resolution,
            movement_amplitude=movement_amplitude,
        )
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
        return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||||
|
|
||||||
|
|
||||||
|
class ViduImageToVideoNode(comfy_io.ComfyNode):
    """API node that generates a Vidu video from a start image and optional prompt."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs, its video output and hidden auth/id fields."""
        return comfy_io.Schema(
            node_id="ViduImageToVideoNode",
            display_name="Vidu Image To Video Generation",
            category="api node/video/Vidu",
            description="Generate video from image and optional prompt",
            inputs=[
                comfy_io.Combo.Input(
                    "model",
                    options=[model.value for model in VideoModelName],
                    default=VideoModelName.vidu_q1.value,
                    tooltip="Model name",
                ),
                comfy_io.Image.Input(
                    "image",
                    tooltip="An image to be used as the start frame of the generated video",
                ),
                comfy_io.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="A textual description for video generation",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "duration",
                    default=5,
                    min=5,
                    max=5,
                    step=1,
                    display_mode=comfy_io.NumberDisplay.number,
                    tooltip="Duration of the output video in seconds",
                    optional=True,
                ),
                comfy_io.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=comfy_io.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed for video generation (0 for random)",
                    optional=True,
                ),
                comfy_io.Combo.Input(
                    "resolution",
                    options=[res.value for res in Resolution],
                    default=Resolution.r_1080p.value,
                    tooltip="Supported values may vary by model & duration",
                    optional=True,
                ),
                comfy_io.Combo.Input(
                    "movement_amplitude",
                    options=[amplitude.value for amplitude in MovementAmplitude],
                    default=MovementAmplitude.auto.value,
                    tooltip="The movement amplitude of objects in the frame",
                    optional=True,
                ),
            ],
            outputs=[
                comfy_io.Video.Output(),
            ],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
                comfy_io.Hidden.unique_id,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        image: torch.Tensor,
        prompt: str,
        duration: int,
        seed: int,
        resolution: str,
        movement_amplitude: str,
    ) -> comfy_io.NodeOutput:
        """Upload the start image, submit an image-to-video task and return the video.

        Raises:
            ValueError: If more than one input image is supplied.
        """
        if get_number_of_images(image) > 1:
            raise ValueError("Only one input image is allowed.")
        validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
        payload = TaskCreationRequest(
            # The request field is `model`. Passing `model_name` would be an
            # unknown keyword (ignored by pydantic by default), silently
            # discarding the user's selection in favour of the default.
            model=model,
            prompt=prompt,
            duration=duration,
            seed=seed,
            resolution=resolution,
            movement_amplitude=movement_amplitude,
        )
        auth = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        # The image is uploaded first; the request then carries what the upload returns.
        payload.images = await upload_images_to_comfyapi(
            image,
            max_images=1,
            mime_type="image/png",
            auth_kwargs=auth,
        )
        results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
        return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||||
|
|
||||||
|
|
||||||
|
class ViduReferenceVideoNode(comfy_io.ComfyNode):
    """API node that generates a video from several reference images and a prompt."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs, outputs and hidden auth fields."""
        model_input = comfy_io.Combo.Input(
            "model",
            options=[member.value for member in VideoModelName],
            default=VideoModelName.vidu_q1.value,
            tooltip="Model name",
        )
        images_input = comfy_io.Image.Input(
            "images",
            tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
        )
        prompt_input = comfy_io.String.Input(
            "prompt",
            multiline=True,
            tooltip="A textual description for video generation",
        )
        # min == max == 5: the widget currently allows exactly one duration.
        duration_input = comfy_io.Int.Input(
            "duration",
            default=5,
            min=5,
            max=5,
            step=1,
            display_mode=comfy_io.NumberDisplay.number,
            tooltip="Duration of the output video in seconds",
            optional=True,
        )
        seed_input = comfy_io.Int.Input(
            "seed",
            default=0,
            min=0,
            max=2147483647,
            step=1,
            display_mode=comfy_io.NumberDisplay.number,
            control_after_generate=True,
            tooltip="Seed for video generation (0 for random)",
            optional=True,
        )
        aspect_ratio_input = comfy_io.Combo.Input(
            "aspect_ratio",
            options=[member.value for member in AspectRatio],
            default=AspectRatio.r_16_9.value,
            tooltip="The aspect ratio of the output video",
            optional=True,
        )
        resolution_input = comfy_io.Combo.Input(
            "resolution",
            options=[member.value for member in Resolution],
            default=Resolution.r_1080p.value,
            tooltip="Supported values may vary by model & duration",
            optional=True,
        )
        movement_input = comfy_io.Combo.Input(
            "movement_amplitude",
            options=[member.value for member in MovementAmplitude],
            default=MovementAmplitude.auto.value,
            tooltip="The movement amplitude of objects in the frame",
            optional=True,
        )
        return comfy_io.Schema(
            node_id="ViduReferenceVideoNode",
            display_name="Vidu Reference To Video Generation",
            category="api node/video/Vidu",
            description="Generate video from multiple images and prompt",
            inputs=[
                model_input,
                images_input,
                prompt_input,
                duration_input,
                seed_input,
                aspect_ratio_input,
                resolution_input,
                movement_input,
            ],
            outputs=[
                comfy_io.Video.Output(),
            ],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
                comfy_io.Hidden.unique_id,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        images: torch.Tensor,
        prompt: str,
        duration: int,
        seed: int,
        aspect_ratio: str,
        resolution: str,
        movement_amplitude: str,
    ) -> comfy_io.NodeOutput:
        """Validate the inputs, upload the reference images and run the task.

        Raises:
            ValueError: if the prompt is empty or more than 7 images are given.
                The per-image validators may raise as well.
        """
        if not prompt:
            raise ValueError("The prompt field is required and cannot be empty.")
        image_count = get_number_of_images(images)
        if image_count > 7:
            raise ValueError("Too many images, maximum allowed is 7.")
        # Each element yielded by iterating `images` is checked on its own.
        for reference in images:
            validate_image_aspect_ratio_range(reference, (1, 4), (4, 1))
            validate_image_dimensions(reference, min_width=128, min_height=128)

        request = TaskCreationRequest(
            model_name=model,
            prompt=prompt,
            duration=duration,
            seed=seed,
            aspect_ratio=aspect_ratio,
            resolution=resolution,
            movement_amplitude=movement_amplitude,
        )
        credentials = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }
        request.images = await upload_images_to_comfyapi(
            images,
            max_images=7,
            mime_type="image/png",
            auth_kwargs=credentials,
        )
        # NOTE(review): 120 is passed positionally to execute_task; presumably
        # an estimated duration -- confirm against execute_task's signature.
        task_result = await execute_task(VIDU_REFERENCE_VIDEO, credentials, request, 120, cls.hidden.unique_id)
        video_url = get_video_from_response(task_result).url
        return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
    """API node that generates a video between a start frame and an end frame."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs, outputs and hidden auth fields."""
        model_input = comfy_io.Combo.Input(
            "model",
            options=[member.value for member in VideoModelName],
            default=VideoModelName.vidu_q1.value,
            tooltip="Model name",
        )
        first_frame_input = comfy_io.Image.Input(
            "first_frame",
            tooltip="Start frame",
        )
        end_frame_input = comfy_io.Image.Input(
            "end_frame",
            tooltip="End frame",
        )
        prompt_input = comfy_io.String.Input(
            "prompt",
            multiline=True,
            tooltip="A textual description for video generation",
            optional=True,
        )
        # min == max == 5: the widget currently allows exactly one duration.
        duration_input = comfy_io.Int.Input(
            "duration",
            default=5,
            min=5,
            max=5,
            step=1,
            display_mode=comfy_io.NumberDisplay.number,
            tooltip="Duration of the output video in seconds",
            optional=True,
        )
        seed_input = comfy_io.Int.Input(
            "seed",
            default=0,
            min=0,
            max=2147483647,
            step=1,
            display_mode=comfy_io.NumberDisplay.number,
            control_after_generate=True,
            tooltip="Seed for video generation (0 for random)",
            optional=True,
        )
        resolution_input = comfy_io.Combo.Input(
            "resolution",
            options=[member.value for member in Resolution],
            default=Resolution.r_1080p.value,
            tooltip="Supported values may vary by model & duration",
            optional=True,
        )
        movement_input = comfy_io.Combo.Input(
            "movement_amplitude",
            options=[member.value for member in MovementAmplitude],
            default=MovementAmplitude.auto.value,
            tooltip="The movement amplitude of objects in the frame",
            optional=True,
        )
        return comfy_io.Schema(
            node_id="ViduStartEndToVideoNode",
            display_name="Vidu Start End To Video Generation",
            category="api node/video/Vidu",
            description="Generate a video from start and end frames and a prompt",
            inputs=[
                model_input,
                first_frame_input,
                end_frame_input,
                prompt_input,
                duration_input,
                seed_input,
                resolution_input,
                movement_input,
            ],
            outputs=[
                comfy_io.Video.Output(),
            ],
            hidden=[
                comfy_io.Hidden.auth_token_comfy_org,
                comfy_io.Hidden.api_key_comfy_org,
                comfy_io.Hidden.unique_id,
            ],
            is_api_node=True,
        )

    @classmethod
    async def execute(
        cls,
        model: str,
        first_frame: torch.Tensor,
        end_frame: torch.Tensor,
        prompt: str,
        duration: int,
        seed: int,
        resolution: str,
        movement_amplitude: str,
    ) -> comfy_io.NodeOutput:
        """Check the two frames have similar aspect ratios, upload them and run the task."""
        validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)

        request = TaskCreationRequest(
            model_name=model,
            prompt=prompt,
            duration=duration,
            seed=seed,
            resolution=resolution,
            movement_amplitude=movement_amplitude,
        )
        credentials = {
            "auth_token": cls.hidden.auth_token_comfy_org,
            "comfy_api_key": cls.hidden.api_key_comfy_org,
        }

        # Upload the frames one after another, start frame first; only the
        # first entry of each upload result is kept.
        uploaded_frames = []
        for frame in (first_frame, end_frame):
            uploaded = await upload_images_to_comfyapi(
                frame, max_images=1, mime_type="image/png", auth_kwargs=credentials
            )
            uploaded_frames.append(uploaded[0])
        request.images = uploaded_frames

        # NOTE(review): 96 is passed positionally to execute_task; presumably
        # an estimated duration -- confirm against execute_task's signature.
        task_result = await execute_task(VIDU_START_END_VIDEO, credentials, request, 96, cls.hidden.unique_id)
        video_url = get_video_from_response(task_result).url
        return comfy_io.NodeOutput(await download_url_to_video_output(video_url))
|
||||||
|
|
||||||
|
|
||||||
|
class ViduExtension(ComfyExtension):
    """Extension that exposes the Vidu API node classes."""

    @override
    async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
        """Return the four Vidu node classes in registration order."""
        vidu_nodes: list[type[comfy_io.ComfyNode]] = [
            ViduTextToVideoNode,
            ViduImageToVideoNode,
            ViduReferenceVideoNode,
            ViduStartEndToVideoNode,
        ]
        return vidu_nodes
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ViduExtension:
    """Create and return the Vidu extension instance."""
    extension = ViduExtension()
    return extension
|
||||||
@ -53,6 +53,53 @@ def validate_image_aspect_ratio(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_image_aspect_ratio_range(
    image: torch.Tensor,
    min_ratio: tuple[float, float],  # e.g. (1, 4)
    max_ratio: tuple[float, float],  # e.g. (4, 1)
    *,
    strict: bool = True,  # True -> (min, max); False -> [min, max]
) -> float:
    """Check that an image's width/height ratio lies between two bounds.

    Args:
        image: Image passed to ``get_image_dimensions``, whose result is
            unpacked here as ``(width, height)``.
        min_ratio: One bound as a ``(numerator, denominator)`` pair.
        max_ratio: The other bound, same form. The two bounds are put in
            ascending order automatically.
        strict: When True the bounds are exclusive, otherwise inclusive.

    Returns:
        The computed ratio ``width / height``.

    Raises:
        ValueError: if any ratio component is not positive, the image
            dimensions are not positive, or the ratio is outside the range.
    """
    lo_num, lo_den = min_ratio
    hi_num, hi_den = max_ratio
    if any(part <= 0 for part in (lo_num, lo_den, hi_num, hi_den)):
        raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")

    lower = lo_num / lo_den
    upper = hi_num / hi_den
    if lower > upper:
        lower, upper = upper, lower
        # Swap the pairs too so the error text shows the bounds in order.
        (lo_num, lo_den), (hi_num, hi_den) = (hi_num, hi_den), (lo_num, lo_den)

    width, height = get_image_dimensions(image)
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid image dimensions: {width}x{height}")

    ratio = width / height
    if strict:
        within = lower < ratio < upper
        op = "<"
    else:
        within = lower <= ratio <= upper
        op = "≤"
    if within:
        return ratio
    raise ValueError(
        f"Image aspect ratio {ratio:.6g} is outside allowed range: {lo_num}:{lo_den} {op} ratio {op} {hi_num}:{hi_den}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def validate_aspect_ratio_closeness(
    start_img,
    end_img,
    min_rel: float,
    max_rel: float,
    *,
    strict: bool = False,  # True => exclusive, False => inclusive
) -> None:
    """Check that two images have similar aspect ratios.

    The larger aspect ratio is divided by the smaller one, and the result is
    compared against ``max(max_rel, 1 / min_rel)``.

    Args:
        start_img: First image, passed to ``get_image_dimensions``.
        end_img: Second image, passed to ``get_image_dimensions``.
        min_rel: Lower relative bound (e.g. 0.8); must be positive because its
            reciprocal is used.
        max_rel: Upper relative bound (e.g. 1.25).
        strict: When True the limit itself is rejected, otherwise allowed.

    Raises:
        ValueError: if ``min_rel`` is not positive, any image dimension is not
            positive, or the aspect ratios are too far apart.
    """
    # Reject an unusable bound up front instead of failing later with a bare
    # ZeroDivisionError (min_rel == 0) or silently ignoring a negative bound.
    if min_rel <= 0:
        raise ValueError("min_rel must be positive.")

    w1, h1 = get_image_dimensions(start_img)
    w2, h2 = get_image_dimensions(end_img)
    if min(w1, h1, w2, h2) <= 0:
        raise ValueError("Invalid image dimensions")
    ar1 = w1 / h1
    ar2 = w2 / h2
    # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
    closeness = max(ar1, ar2) / min(ar1, ar2)
    limit = max(max_rel, 1.0 / min_rel)  # for 0.8..1.25 this is 1.25
    too_far = (closeness >= limit) if strict else (closeness > limit)
    if too_far:
        raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
|
||||||
|
|
||||||
|
|
||||||
def validate_video_dimensions(
|
def validate_video_dimensions(
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
min_width: Optional[int] = None,
|
min_width: Optional[int] = None,
|
||||||
@ -98,3 +145,9 @@ def validate_video_duration(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Video duration must be at most {max_duration}s, got {duration}s"
|
f"Video duration must be at most {max_duration}s, got {duration}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_number_of_images(images):
    """Return how many images ``images`` holds.

    A tensor with four or more dimensions counts its first dimension; a
    tensor with fewer dimensions counts as a single image. Any other value
    is measured with ``len``.
    """
    if not isinstance(images, torch.Tensor):
        return len(images)
    if images.ndim >= 4:
        return images.shape[0]
    return 1
|
||||||
|
|||||||
@ -1,53 +1,65 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy import node_helpers
|
from comfy import node_helpers
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class TextEncodeAceStepAudio:
|
class TextEncodeAceStepAudio(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP",),
|
node_id="TextEncodeAceStepAudio",
|
||||||
"tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="conditioning",
|
||||||
"lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||||
|
io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
|
||||||
|
io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Conditioning.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "encode"
|
def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "conditioning"
|
|
||||||
|
|
||||||
def encode(self, clip, tags, lyrics, lyrics_strength):
|
|
||||||
tokens = clip.tokenize(tags, lyrics=lyrics)
|
tokens = clip.tokenize(tags, lyrics=lyrics)
|
||||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
|
||||||
return (conditioning,)
|
return io.NodeOutput(conditioning)
|
||||||
|
|
||||||
|
|
||||||
class EmptyAceStepLatentAudio:
|
class EmptyAceStepLatentAudio(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.device = comfy.model_management.intermediate_device()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptyAceStepLatentAudio",
|
||||||
|
category="latent/audio",
|
||||||
|
inputs=[
|
||||||
|
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||||
|
io.Int.Input(
|
||||||
|
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||||
|
),
|
||||||
|
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, seconds, batch_size) -> io.NodeOutput:
|
||||||
return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
|
||||||
}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
|
||||||
FUNCTION = "generate"
|
|
||||||
|
|
||||||
CATEGORY = "latent/audio"
|
|
||||||
|
|
||||||
def generate(self, seconds, batch_size):
|
|
||||||
length = int(seconds * 44100 / 512 / 8)
|
length = int(seconds * 44100 / 512 / 8)
|
||||||
latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
|
latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples": latent, "type": "audio"},)
|
return io.NodeOutput({"samples": latent, "type": "audio"})
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class AceExtension(ComfyExtension):
|
||||||
"TextEncodeAceStepAudio": TextEncodeAceStepAudio,
|
@override
|
||||||
"EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
}
|
return [
|
||||||
|
TextEncodeAceStepAudio,
|
||||||
|
EmptyAceStepLatentAudio,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> AceExtension:
|
||||||
|
return AceExtension()
|
||||||
|
|||||||
@ -1,8 +1,13 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from tqdm.auto import trange
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import comfy.model_patcher
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import torch
|
from comfy.k_diffusion.sampling import to_d
|
||||||
import numpy as np
|
from comfy_api.latest import ComfyExtension, io
|
||||||
from tqdm.auto import trange
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -33,30 +38,29 @@ def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SamplerLCMUpscale:
|
class SamplerLCMUpscale(io.ComfyNode):
|
||||||
upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
UPSCALE_METHODS = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required":
|
return io.Schema(
|
||||||
{"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}),
|
node_id="SamplerLCMUpscale",
|
||||||
"scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}),
|
category="sampling/custom_sampling/samplers",
|
||||||
"upscale_method": (s.upscale_methods,),
|
inputs=[
|
||||||
}
|
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01),
|
||||||
}
|
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1),
|
||||||
RETURN_TYPES = ("SAMPLER",)
|
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||||
CATEGORY = "sampling/custom_sampling/samplers"
|
],
|
||||||
|
outputs=[io.Sampler.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "get_sampler"
|
@classmethod
|
||||||
|
def execute(cls, scale_ratio, scale_steps, upscale_method) -> io.NodeOutput:
|
||||||
def get_sampler(self, scale_ratio, scale_steps, upscale_method):
|
|
||||||
if scale_steps < 0:
|
if scale_steps < 0:
|
||||||
scale_steps = None
|
scale_steps = None
|
||||||
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method})
|
||||||
return (sampler, )
|
return io.NodeOutput(sampler)
|
||||||
|
|
||||||
from comfy.k_diffusion.sampling import to_d
|
|
||||||
import comfy.model_patcher
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||||
@ -82,30 +86,36 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class SamplerEulerCFGpp:
|
class SamplerEulerCFGpp(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required":
|
return io.Schema(
|
||||||
{"version": (["regular", "alternative"],),}
|
node_id="SamplerEulerCFGpp",
|
||||||
}
|
display_name="SamplerEulerCFG++",
|
||||||
RETURN_TYPES = ("SAMPLER",)
|
category="_for_testing", # "sampling/custom_sampling/samplers"
|
||||||
# CATEGORY = "sampling/custom_sampling/samplers"
|
inputs=[
|
||||||
CATEGORY = "_for_testing"
|
io.Combo.Input("version", options=["regular", "alternative"]),
|
||||||
|
],
|
||||||
|
outputs=[io.Sampler.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "get_sampler"
|
@classmethod
|
||||||
|
def execute(cls, version) -> io.NodeOutput:
|
||||||
def get_sampler(self, version):
|
|
||||||
if version == "alternative":
|
if version == "alternative":
|
||||||
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
sampler = comfy.samplers.KSAMPLER(sample_euler_pp)
|
||||||
else:
|
else:
|
||||||
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
sampler = comfy.samplers.ksampler("euler_cfg_pp")
|
||||||
return (sampler, )
|
return io.NodeOutput(sampler)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"SamplerLCMUpscale": SamplerLCMUpscale,
|
|
||||||
"SamplerEulerCFGpp": SamplerEulerCFGpp,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
class AdvancedSamplersExtension(ComfyExtension):
|
||||||
"SamplerEulerCFGpp": "SamplerEulerCFG++",
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
SamplerLCMUpscale,
|
||||||
|
SamplerEulerCFGpp,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> AdvancedSamplersExtension:
|
||||||
|
return AdvancedSamplersExtension()
|
||||||
|
|||||||
@ -1,4 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def project(v0, v1):
|
def project(v0, v1):
|
||||||
@ -8,23 +12,46 @@ def project(v0, v1):
|
|||||||
return v0_parallel, v0_orthogonal
|
return v0_parallel, v0_orthogonal
|
||||||
|
|
||||||
|
|
||||||
class APG:
|
class APG(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="APG",
|
||||||
"model": ("MODEL",),
|
display_name="Adaptive Projected Guidance",
|
||||||
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
|
category="sampling/custom_sampling",
|
||||||
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
|
inputs=[
|
||||||
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip": "Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
|
io.Model.Input("model"),
|
||||||
}
|
io.Float.Input(
|
||||||
}
|
"eta",
|
||||||
|
default=1.0,
|
||||||
|
min=-10.0,
|
||||||
|
max=10.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"norm_threshold",
|
||||||
|
default=5.0,
|
||||||
|
min=0.0,
|
||||||
|
max=50.0,
|
||||||
|
step=0.1,
|
||||||
|
tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"momentum",
|
||||||
|
default=0.0,
|
||||||
|
min=-5.0,
|
||||||
|
max=1.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip= "Controls a running average of guidance during diffusion, disabled at a setting of 0.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
|
||||||
RETURN_TYPES = ("MODEL",)
|
outputs=[io.Model.Output()],
|
||||||
FUNCTION = "patch"
|
)
|
||||||
CATEGORY = "sampling/custom_sampling"
|
|
||||||
|
|
||||||
def patch(self, model, eta, norm_threshold, momentum):
|
@classmethod
|
||||||
|
def execute(cls, model, eta, norm_threshold, momentum) -> io.NodeOutput:
|
||||||
running_avg = 0
|
running_avg = 0
|
||||||
prev_sigma = None
|
prev_sigma = None
|
||||||
|
|
||||||
@ -68,13 +95,15 @@ class APG:
|
|||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||||
return (m,)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class ApgExtension(ComfyExtension):
|
||||||
"APG": APG,
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
APG,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
async def comfy_entrypoint() -> ApgExtension:
|
||||||
"APG": "Adaptive Projected Guidance",
|
return ApgExtension()
|
||||||
}
|
|
||||||
|
|||||||
@ -1,3 +1,7 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def attention_multiply(attn, model, q, k, v, out):
|
def attention_multiply(attn, model, q, k, v, out):
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
@ -16,57 +20,71 @@ def attention_multiply(attn, model, q, k, v, out):
|
|||||||
return m
|
return m
|
||||||
|
|
||||||
|
|
||||||
class UNetSelfAttentionMultiply:
|
class UNetSelfAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="UNetSelfAttentionMultiply",
|
||||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||||
def patch(self, model, q, k, v, out):
|
|
||||||
m = attention_multiply("attn1", model, q, k, v, out)
|
m = attention_multiply("attn1", model, q, k, v, out)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class UNetCrossAttentionMultiply:
|
|
||||||
|
class UNetCrossAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="UNetCrossAttentionMultiply",
|
||||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, model, q, k, v, out) -> io.NodeOutput:
|
||||||
def patch(self, model, q, k, v, out):
|
|
||||||
m = attention_multiply("attn2", model, q, k, v, out)
|
m = attention_multiply("attn2", model, q, k, v, out)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class CLIPAttentionMultiply:
|
|
||||||
|
class CLIPAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "clip": ("CLIP",),
|
return io.Schema(
|
||||||
"q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="CLIPAttentionMultiply",
|
||||||
"k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("CLIP",)
|
io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Clip.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, clip, q, k, v, out) -> io.NodeOutput:
|
||||||
def patch(self, clip, q, k, v, out):
|
|
||||||
m = clip.clone()
|
m = clip.clone()
|
||||||
sd = m.patcher.model_state_dict()
|
sd = m.patcher.model_state_dict()
|
||||||
|
|
||||||
@ -79,23 +97,28 @@ class CLIPAttentionMultiply:
|
|||||||
m.add_patches({key: (None,)}, 0.0, v)
|
m.add_patches({key: (None,)}, 0.0, v)
|
||||||
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"):
|
||||||
m.add_patches({key: (None,)}, 0.0, out)
|
m.add_patches({key: (None,)}, 0.0, out)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class UNetTemporalAttentionMultiply:
|
|
||||||
|
class UNetTemporalAttentionMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls) -> io.Schema:
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="UNetTemporalAttentionMultiply",
|
||||||
"self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="_for_testing/attention_experiments",
|
||||||
"cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.Model.Input("model"),
|
||||||
}}
|
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[io.Model.Output()],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "_for_testing/attention_experiments"
|
@classmethod
|
||||||
|
def execute(cls, model, self_structural, self_temporal, cross_structural, cross_temporal) -> io.NodeOutput:
|
||||||
def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal):
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
sd = model.model_state_dict()
|
sd = model.model_state_dict()
|
||||||
|
|
||||||
@ -110,11 +133,18 @@ class UNetTemporalAttentionMultiply:
|
|||||||
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
m.add_patches({k: (None,)}, 0.0, cross_temporal)
|
||||||
else:
|
else:
|
||||||
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
m.add_patches({k: (None,)}, 0.0, cross_structural)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
|
class AttentionMultiplyExtension(ComfyExtension):
|
||||||
"UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
|
@override
|
||||||
"CLIPAttentionMultiply": CLIPAttentionMultiply,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
|
return [
|
||||||
}
|
UNetSelfAttentionMultiply,
|
||||||
|
UNetCrossAttentionMultiply,
|
||||||
|
CLIPAttentionMultiply,
|
||||||
|
UNetTemporalAttentionMultiply,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> AttentionMultiplyExtension:
|
||||||
|
return AttentionMultiplyExtension()
|
||||||
|
|||||||
@ -379,6 +379,27 @@ class LoadAudio:
|
|||||||
return "Invalid audio file: {}".format(audio)
|
return "Invalid audio file: {}".format(audio)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
class RecordAudio:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {"audio": ("AUDIO_RECORD", {})}}
|
||||||
|
|
||||||
|
CATEGORY = "audio"
|
||||||
|
|
||||||
|
RETURN_TYPES = ("AUDIO", )
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(self, audio):
|
||||||
|
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||||
|
try:
|
||||||
|
import torchaudio # pylint: disable=import-error
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
raise TorchAudioNotFoundError()
|
||||||
|
|
||||||
|
waveform, sample_rate = torchaudio.load(audio_path)
|
||||||
|
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||||
|
return (audio, )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"EmptyLatentAudio": EmptyLatentAudio,
|
"EmptyLatentAudio": EmptyLatentAudio,
|
||||||
@ -390,6 +411,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LoadAudio": LoadAudio,
|
"LoadAudio": LoadAudio,
|
||||||
"PreviewAudio": PreviewAudio,
|
"PreviewAudio": PreviewAudio,
|
||||||
"ConditioningStableAudio": ConditioningStableAudio,
|
"ConditioningStableAudio": ConditioningStableAudio,
|
||||||
|
"RecordAudio": RecordAudio,
|
||||||
}
|
}
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
@ -401,4 +423,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"SaveAudio": "Save Audio (FLAC)",
|
"SaveAudio": "Save Audio (FLAC)",
|
||||||
"SaveAudioMP3": "Save Audio (MP3)",
|
"SaveAudioMP3": "Save Audio (MP3)",
|
||||||
"SaveAudioOpus": "Save Audio (Opus)",
|
"SaveAudioOpus": "Save Audio (Opus)",
|
||||||
|
"RecordAudio": "Record Audio",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -83,9 +83,28 @@ class FluxKontextImageScale:
|
|||||||
return (image,)
|
return (image,)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxKontextMultiReferenceLatentMethod:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": {
|
||||||
|
"conditioning": ("CONDITIONING", ),
|
||||||
|
"reference_latents_method": (("offset", "index"), ),
|
||||||
|
}}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("CONDITIONING",)
|
||||||
|
FUNCTION = "append"
|
||||||
|
EXPERIMENTAL = True
|
||||||
|
|
||||||
|
CATEGORY = "advanced/conditioning/flux"
|
||||||
|
|
||||||
|
def append(self, conditioning, reference_latents_method):
|
||||||
|
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
||||||
|
return (c, )
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
||||||
"FluxGuidance": FluxGuidance,
|
"FluxGuidance": FluxGuidance,
|
||||||
"FluxDisableGuidance": FluxDisableGuidance,
|
"FluxDisableGuidance": FluxDisableGuidance,
|
||||||
"FluxKontextImageScale": FluxKontextImageScale,
|
"FluxKontextImageScale": FluxKontextImageScale,
|
||||||
|
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -172,7 +172,7 @@ class LTXVAddGuide:
|
|||||||
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
|
||||||
|
|
||||||
mask = torch.full(
|
mask = torch.full(
|
||||||
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
|
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
|
||||||
1.0 - strength,
|
1.0 - strength,
|
||||||
dtype=noise_mask.dtype,
|
dtype=noise_mask.dtype,
|
||||||
device=noise_mask.device,
|
device=noise_mask.device,
|
||||||
|
|||||||
@ -1,81 +1,91 @@
|
|||||||
import re
|
import re
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy.comfy_types.node_typing import IO
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class StringConcatenate():
|
class StringConcatenate(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="StringConcatenate",
|
||||||
"string_a": (IO.STRING, {"multiline": True}),
|
display_name="Concatenate",
|
||||||
"string_b": (IO.STRING, {"multiline": True}),
|
category="utils/string",
|
||||||
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
|
inputs=[
|
||||||
}
|
io.String.Input("string_a", multiline=True),
|
||||||
}
|
io.String.Input("string_b", multiline=True),
|
||||||
|
io.String.Input("delimiter", multiline=False, default=""),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.STRING,)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string_a, string_b, delimiter, **kwargs):
|
|
||||||
return delimiter.join((string_a, string_b)),
|
|
||||||
|
|
||||||
|
|
||||||
class StringSubstring():
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, string_a, string_b, delimiter):
|
||||||
return {
|
return io.NodeOutput(delimiter.join((string_a, string_b)))
|
||||||
"required": {
|
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
|
||||||
"start": (IO.INT, {}),
|
|
||||||
"end": (IO.INT, {}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.STRING,)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, start, end, **kwargs):
|
|
||||||
return string[start:end],
|
|
||||||
|
|
||||||
|
|
||||||
class StringLength():
|
class StringSubstring(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="StringSubstring",
|
||||||
"string": (IO.STRING, {"multiline": True})
|
display_name="Substring",
|
||||||
}
|
category="utils/string",
|
||||||
}
|
inputs=[
|
||||||
|
io.String.Input("string", multiline=True),
|
||||||
|
io.Int.Input("start"),
|
||||||
|
io.Int.Input("end"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.INT,)
|
|
||||||
RETURN_NAMES = ("length",)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, **kwargs):
|
|
||||||
length = len(string)
|
|
||||||
|
|
||||||
return length,
|
|
||||||
|
|
||||||
|
|
||||||
class CaseConverter():
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, string, start, end):
|
||||||
return {
|
return io.NodeOutput(string[start:end])
|
||||||
"required": {
|
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
|
||||||
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.STRING,)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, mode, **kwargs):
|
class StringLength(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="StringLength",
|
||||||
|
display_name="Length",
|
||||||
|
category="utils/string",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("string", multiline=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Int.Output(display_name="length"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, string):
|
||||||
|
return io.NodeOutput(len(string))
|
||||||
|
|
||||||
|
|
||||||
|
class CaseConverter(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CaseConverter",
|
||||||
|
display_name="Case Converter",
|
||||||
|
category="utils/string",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("string", multiline=True),
|
||||||
|
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, string, mode):
|
||||||
if mode == "UPPERCASE":
|
if mode == "UPPERCASE":
|
||||||
result = string.upper()
|
result = string.upper()
|
||||||
elif mode == "lowercase":
|
elif mode == "lowercase":
|
||||||
@ -87,24 +97,27 @@ class CaseConverter():
|
|||||||
else:
|
else:
|
||||||
result = string
|
result = string
|
||||||
|
|
||||||
return result,
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
class StringTrim():
|
class StringTrim(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="StringTrim",
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
display_name="Trim",
|
||||||
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
|
category="utils/string",
|
||||||
}
|
inputs=[
|
||||||
}
|
io.String.Input("string", multiline=True),
|
||||||
|
io.Combo.Input("mode", options=["Both", "Left", "Right"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.STRING,)
|
@classmethod
|
||||||
FUNCTION = "execute"
|
def execute(cls, string, mode):
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, mode, **kwargs):
|
|
||||||
if mode == "Both":
|
if mode == "Both":
|
||||||
result = string.strip()
|
result = string.strip()
|
||||||
elif mode == "Left":
|
elif mode == "Left":
|
||||||
@ -114,71 +127,78 @@ class StringTrim():
|
|||||||
else:
|
else:
|
||||||
result = string
|
result = string
|
||||||
|
|
||||||
return result,
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
class StringReplace():
|
class StringReplace(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="StringReplace",
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
display_name="Replace",
|
||||||
"find": (IO.STRING, {"multiline": True}),
|
category="utils/string",
|
||||||
"replace": (IO.STRING, {"multiline": True})
|
inputs=[
|
||||||
}
|
io.String.Input("string", multiline=True),
|
||||||
}
|
io.String.Input("find", multiline=True),
|
||||||
|
io.String.Input("replace", multiline=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.STRING,)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, find, replace, **kwargs):
|
|
||||||
result = string.replace(find, replace)
|
|
||||||
return result,
|
|
||||||
|
|
||||||
|
|
||||||
class StringContains():
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, string, find, replace):
|
||||||
return {
|
return io.NodeOutput(string.replace(find, replace))
|
||||||
"required": {
|
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
|
||||||
"substring": (IO.STRING, {"multiline": True}),
|
|
||||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.BOOLEAN,)
|
|
||||||
RETURN_NAMES = ("contains",)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, substring, case_sensitive, **kwargs):
|
class StringContains(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="StringContains",
|
||||||
|
display_name="Contains",
|
||||||
|
category="utils/string",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("string", multiline=True),
|
||||||
|
io.String.Input("substring", multiline=True),
|
||||||
|
io.Boolean.Input("case_sensitive", default=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Boolean.Output(display_name="contains"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, string, substring, case_sensitive):
|
||||||
if case_sensitive:
|
if case_sensitive:
|
||||||
contains = substring in string
|
contains = substring in string
|
||||||
else:
|
else:
|
||||||
contains = substring.lower() in string.lower()
|
contains = substring.lower() in string.lower()
|
||||||
|
|
||||||
return contains,
|
return io.NodeOutput(contains)
|
||||||
|
|
||||||
|
|
||||||
class StringCompare():
|
class StringCompare(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="StringCompare",
|
||||||
"string_a": (IO.STRING, {"multiline": True}),
|
display_name="Compare",
|
||||||
"string_b": (IO.STRING, {"multiline": True}),
|
category="utils/string",
|
||||||
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
|
inputs=[
|
||||||
"case_sensitive": (IO.BOOLEAN, {"default": True})
|
io.String.Input("string_a", multiline=True),
|
||||||
}
|
io.String.Input("string_b", multiline=True),
|
||||||
}
|
io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]),
|
||||||
|
io.Boolean.Input("case_sensitive", default=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Boolean.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.BOOLEAN,)
|
@classmethod
|
||||||
FUNCTION = "execute"
|
def execute(cls, string_a, string_b, mode, case_sensitive):
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
|
|
||||||
if case_sensitive:
|
if case_sensitive:
|
||||||
a = string_a
|
a = string_a
|
||||||
b = string_b
|
b = string_b
|
||||||
@ -187,32 +207,34 @@ class StringCompare():
|
|||||||
b = string_b.lower()
|
b = string_b.lower()
|
||||||
|
|
||||||
if mode == "Equal":
|
if mode == "Equal":
|
||||||
return a == b,
|
return io.NodeOutput(a == b)
|
||||||
elif mode == "Starts With":
|
elif mode == "Starts With":
|
||||||
return a.startswith(b),
|
return io.NodeOutput(a.startswith(b))
|
||||||
elif mode == "Ends With":
|
elif mode == "Ends With":
|
||||||
return a.endswith(b),
|
return io.NodeOutput(a.endswith(b))
|
||||||
|
|
||||||
|
|
||||||
class RegexMatch():
|
class RegexMatch(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="RegexMatch",
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
display_name="Regex Match",
|
||||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
category="utils/string",
|
||||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
inputs=[
|
||||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
io.String.Input("string", multiline=True),
|
||||||
"dotall": (IO.BOOLEAN, {"default": False})
|
io.String.Input("regex_pattern", multiline=True),
|
||||||
}
|
io.Boolean.Input("case_insensitive", default=True),
|
||||||
}
|
io.Boolean.Input("multiline", default=False),
|
||||||
|
io.Boolean.Input("dotall", default=False),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Boolean.Output(display_name="matches"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.BOOLEAN,)
|
@classmethod
|
||||||
RETURN_NAMES = ("matches",)
|
def execute(cls, string, regex_pattern, case_insensitive, multiline, dotall):
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
|
|
||||||
flags = 0
|
flags = 0
|
||||||
|
|
||||||
if case_insensitive:
|
if case_insensitive:
|
||||||
@ -229,29 +251,32 @@ class RegexMatch():
|
|||||||
except re.error:
|
except re.error:
|
||||||
result = False
|
result = False
|
||||||
|
|
||||||
return result,
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
class RegexExtract():
|
class RegexExtract(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="RegexExtract",
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
display_name="Regex Extract",
|
||||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
category="utils/string",
|
||||||
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
|
inputs=[
|
||||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
io.String.Input("string", multiline=True),
|
||||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
io.String.Input("regex_pattern", multiline=True),
|
||||||
"dotall": (IO.BOOLEAN, {"default": False}),
|
io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]),
|
||||||
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
|
io.Boolean.Input("case_insensitive", default=True),
|
||||||
}
|
io.Boolean.Input("multiline", default=False),
|
||||||
}
|
io.Boolean.Input("dotall", default=False),
|
||||||
|
io.Int.Input("group_index", default=1, min=0, max=100),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = (IO.STRING,)
|
@classmethod
|
||||||
FUNCTION = "execute"
|
def execute(cls, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index):
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
|
|
||||||
join_delimiter = "\n"
|
join_delimiter = "\n"
|
||||||
|
|
||||||
flags = 0
|
flags = 0
|
||||||
@ -300,33 +325,33 @@ class RegexExtract():
|
|||||||
except re.error:
|
except re.error:
|
||||||
result = ""
|
result = ""
|
||||||
|
|
||||||
return result,
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
class RegexReplace():
|
class RegexReplace(io.ComfyNode):
|
||||||
DESCRIPTION = "Find and replace text using regex patterns."
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="RegexReplace",
|
||||||
|
display_name="Regex Replace",
|
||||||
|
category="utils/string",
|
||||||
|
description="Find and replace text using regex patterns.",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("string", multiline=True),
|
||||||
|
io.String.Input("regex_pattern", multiline=True),
|
||||||
|
io.String.Input("replace", multiline=True),
|
||||||
|
io.Boolean.Input("case_insensitive", default=True, optional=True),
|
||||||
|
io.Boolean.Input("multiline", default=False, optional=True),
|
||||||
|
io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."),
|
||||||
|
io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0):
|
||||||
return {
|
|
||||||
"required": {
|
|
||||||
"string": (IO.STRING, {"multiline": True}),
|
|
||||||
"regex_pattern": (IO.STRING, {"multiline": True}),
|
|
||||||
"replace": (IO.STRING, {"multiline": True}),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"case_insensitive": (IO.BOOLEAN, {"default": True}),
|
|
||||||
"multiline": (IO.BOOLEAN, {"default": False}),
|
|
||||||
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
|
|
||||||
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.STRING,)
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "utils/string"
|
|
||||||
|
|
||||||
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
|
|
||||||
flags = 0
|
flags = 0
|
||||||
|
|
||||||
if case_insensitive:
|
if case_insensitive:
|
||||||
@ -336,33 +361,26 @@ class RegexReplace():
|
|||||||
if dotall:
|
if dotall:
|
||||||
flags |= re.DOTALL
|
flags |= re.DOTALL
|
||||||
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
|
||||||
return result,
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class StringExtension(ComfyExtension):
|
||||||
"StringConcatenate": StringConcatenate,
|
@override
|
||||||
"StringSubstring": StringSubstring,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"StringLength": StringLength,
|
return [
|
||||||
"CaseConverter": CaseConverter,
|
StringConcatenate,
|
||||||
"StringTrim": StringTrim,
|
StringSubstring,
|
||||||
"StringReplace": StringReplace,
|
StringLength,
|
||||||
"StringContains": StringContains,
|
CaseConverter,
|
||||||
"StringCompare": StringCompare,
|
StringTrim,
|
||||||
"RegexMatch": RegexMatch,
|
StringReplace,
|
||||||
"RegexExtract": RegexExtract,
|
StringContains,
|
||||||
"RegexReplace": RegexReplace,
|
StringCompare,
|
||||||
}
|
RegexMatch,
|
||||||
|
RegexExtract,
|
||||||
|
RegexReplace,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"StringConcatenate": "Concatenate",
|
async def comfy_entrypoint() -> StringExtension:
|
||||||
"StringSubstring": "Substring",
|
return StringExtension()
|
||||||
"StringLength": "Length",
|
|
||||||
"CaseConverter": "Case Converter",
|
|
||||||
"StringTrim": "Trim",
|
|
||||||
"StringReplace": "Replace",
|
|
||||||
"StringContains": "Contains",
|
|
||||||
"StringCompare": "Compare",
|
|
||||||
"RegexMatch": "Regex Match",
|
|
||||||
"RegexExtract": "Regex Extract",
|
|
||||||
"RegexReplace": "Regex Replace",
|
|
||||||
}
|
|
||||||
|
|||||||
@ -4,37 +4,45 @@ from typing import Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.clip_vision
|
import comfy.clip_vision
|
||||||
|
import comfy.clip_vision
|
||||||
|
import comfy.latent_formats
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from comfy import node_helpers
|
from comfy import node_helpers
|
||||||
from comfy.nodes import base_nodes as nodes
|
from comfy.nodes import base_nodes as nodes
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class WanImageToVideo:
|
class WanImageToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING",),
|
node_id="WanImageToVideo",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Conditioning.Input("positive"),
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("negative"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"start_image": ("IMAGE",),
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
}}
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
|
||||||
|
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None):
|
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
if start_image is not None:
|
if start_image is not None:
|
||||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
@ -54,32 +62,36 @@ class WanImageToVideo:
|
|||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class WanFunControlToVideo:
|
class WanFunControlToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING",),
|
node_id="WanFunControlToVideo",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Conditioning.Input("positive"),
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("negative"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"start_image": ("IMAGE",),
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
"control_video": ("IMAGE",),
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
}}
|
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
io.Image.Input("control_video", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
|
||||||
|
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, control_video=None):
|
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||||
@ -104,33 +116,97 @@ class WanFunControlToVideo:
|
|||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
class Wan22FunControlToVideo(io.ComfyNode):
|
||||||
class WanFirstLastFrameToVideo:
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING",),
|
node_id="Wan22FunControlToVideo",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Conditioning.Input("positive"),
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("negative"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT",),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"clip_vision_end_image": ("CLIP_VISION_OUTPUT",),
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
"start_image": ("IMAGE",),
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
"end_image": ("IMAGE",),
|
io.Image.Input("ref_image", optional=True),
|
||||||
}}
|
io.Image.Input("control_video", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, start_image=None, control_video=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||||
|
concat_latent = concat_latent.repeat(1, 2, 1, 1, 1)
|
||||||
|
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
if start_image is not None:
|
||||||
|
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||||
|
concat_latent[:,16:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||||
|
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||||
|
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
|
ref_latent = None
|
||||||
|
if ref_image is not None:
|
||||||
|
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
ref_latent = vae.encode(ref_image[:, :, :, :3])
|
||||||
|
|
||||||
|
if control_video is not None:
|
||||||
|
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
concat_latent_image = vae.encode(control_video[:, :, :, :3])
|
||||||
|
concat_latent[:,:16,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
|
||||||
|
|
||||||
|
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask, "concat_mask_index": 16})
|
||||||
|
|
||||||
|
if ref_latent is not None:
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
|
class WanFirstLastFrameToVideo(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="WanFirstLastFrameToVideo",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
|
||||||
|
io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
io.Image.Input("end_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
if start_image is not None:
|
if start_image is not None:
|
||||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
@ -171,62 +247,70 @@ class WanFirstLastFrameToVideo:
|
|||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class WanFunInpaintToVideo:
|
class WanFunInpaintToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING",),
|
node_id="WanFunInpaintToVideo",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Conditioning.Input("positive"),
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("negative"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"start_image": ("IMAGE",),
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
"end_image": ("IMAGE",),
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
}}
|
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
io.Image.Input("end_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
|
||||||
|
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
|
|
||||||
flfv = WanFirstLastFrameToVideo()
|
flfv = WanFirstLastFrameToVideo()
|
||||||
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
return flfv.execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
|
||||||
|
|
||||||
|
|
||||||
class WanVaceToVideo:
|
class WanVaceToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING",),
|
node_id="WanVaceToVideo",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
is_experimental=True,
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("positive"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Conditioning.Input("negative"),
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"control_video": ("IMAGE",),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"control_masks": ("MASK",),
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
"reference_image": ("IMAGE",),
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
}}
|
io.Float.Input("strength", default=1.0, min=0.0, max=1000.0, step=0.01),
|
||||||
|
io.Image.Input("control_video", optional=True),
|
||||||
|
io.Mask.Input("control_masks", optional=True),
|
||||||
|
io.Image.Input("reference_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
io.Int.Output(display_name="trim_latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT", "INT")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative", "latent", "trim_latent")
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
|
||||||
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, strength, control_video=None, control_masks=None, reference_image=None):
|
|
||||||
latent_length = ((length - 1) // 4) + 1
|
latent_length = ((length - 1) // 4) + 1
|
||||||
if control_video is not None:
|
if control_video is not None:
|
||||||
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
control_video = comfy.utils.common_upscale(control_video[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
@ -283,54 +367,60 @@ class WanVaceToVideo:
|
|||||||
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent, trim_latent)
|
return io.NodeOutput(positive, negative, out_latent, trim_latent)
|
||||||
|
|
||||||
|
|
||||||
class TrimVideoLatent:
|
class TrimVideoLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"samples": ("LATENT",),
|
return io.Schema(
|
||||||
"trim_amount": ("INT", {"default": 0, "min": 0, "max": 99999}),
|
node_id="TrimVideoLatent",
|
||||||
}}
|
category="latent/video",
|
||||||
|
is_experimental=True,
|
||||||
RETURN_TYPES = ("LATENT",)
|
inputs=[
|
||||||
FUNCTION = "op"
|
io.Latent.Input("samples"),
|
||||||
|
io.Int.Input("trim_amount", default=0, min=0, max=99999),
|
||||||
CATEGORY = "latent/video"
|
],
|
||||||
|
outputs=[
|
||||||
EXPERIMENTAL = True
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
def op(self, samples, trim_amount):
|
)
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, samples, trim_amount) -> io.NodeOutput:
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
|
|
||||||
s1 = samples["samples"]
|
s1 = samples["samples"]
|
||||||
samples_out["samples"] = s1[:, :, trim_amount:]
|
samples_out["samples"] = s1[:, :, trim_amount:]
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
|
|
||||||
class WanCameraImageToVideo:
|
class WanCameraImageToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING",),
|
node_id="WanCameraImageToVideo",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Conditioning.Input("positive"),
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("negative"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"start_image": ("IMAGE",),
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
"camera_conditions": ("WAN_CAMERA_EMBEDDING",),
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
}}
|
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
io.WanCameraEmbedding.Input("camera_conditions", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
|
||||||
|
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
|
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
|
||||||
@ -339,9 +429,12 @@ class WanCameraImageToVideo:
|
|||||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
concat_latent_image = vae.encode(start_image[:, :, :, :3])
|
||||||
concat_latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
|
concat_latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image[:, :, :concat_latent.shape[2]]
|
||||||
|
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
|
||||||
|
mask[:, :, :start_image.shape[0] + 3] = 0.0
|
||||||
|
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
|
||||||
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
|
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
|
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})
|
||||||
|
|
||||||
if camera_conditions is not None:
|
if camera_conditions is not None:
|
||||||
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
|
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
|
||||||
@ -353,30 +446,34 @@ class WanCameraImageToVideo:
|
|||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class WanPhantomSubjectToVideo:
|
class WanPhantomSubjectToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING",),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING",),
|
node_id="WanPhantomSubjectToVideo",
|
||||||
"vae": ("VAE",),
|
category="conditioning/video_models",
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Conditioning.Input("positive"),
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("negative"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"images": ("IMAGE",),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
}}
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
|
io.Image.Input("images", optional=True),
|
||||||
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
|
],
|
||||||
FUNCTION = "encode"
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
CATEGORY = "conditioning/video_models"
|
io.Conditioning.Output(display_name="negative_text"),
|
||||||
|
io.Conditioning.Output(display_name="negative_img_text"),
|
||||||
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, images) -> io.NodeOutput:
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
cond2 = negative
|
cond2 = negative
|
||||||
if images is not None:
|
if images is not None:
|
||||||
@ -392,7 +489,7 @@ class WanPhantomSubjectToVideo:
|
|||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, cond2, negative, out_latent)
|
return io.NodeOutput(positive, cond2, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
def parse_json_tracks(tracks):
|
def parse_json_tracks(tracks):
|
||||||
@ -613,39 +710,40 @@ def patch_motion(
|
|||||||
return out_mask_full, out_feature_full
|
return out_mask_full, out_feature_full
|
||||||
|
|
||||||
|
|
||||||
class WanTrackToVideo:
|
class WanTrackToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"positive": ("CONDITIONING",),
|
node_id="WanTrackToVideo",
|
||||||
"negative": ("CONDITIONING",),
|
category="conditioning/video_models",
|
||||||
"vae": ("VAE",),
|
inputs=[
|
||||||
"tracks": ("STRING", {"multiline": True, "default": "[]"}),
|
io.Conditioning.Input("positive"),
|
||||||
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Conditioning.Input("negative"),
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
io.Vae.Input("vae"),
|
||||||
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.String.Input("tracks", multiline=True, default="[]"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"temperature": ("FLOAT", {"default": 220.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"topk": ("INT", {"default": 2, "min": 1, "max": 10}),
|
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
"start_image": ("IMAGE",),
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
},
|
io.Float.Input("temperature", default=220.0, min=1.0, max=1000.0, step=0.1),
|
||||||
"optional": {
|
io.Int.Input("topk", default=2, min=1, max=10),
|
||||||
"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
io.Image.Input("start_image"),
|
||||||
}}
|
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||||
|
],
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
outputs=[
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
io.Conditioning.Output(display_name="positive"),
|
||||||
FUNCTION = "encode"
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
CATEGORY = "conditioning/video_models"
|
],
|
||||||
|
)
|
||||||
def encode(self, positive, negative, vae, tracks, width, height, length, batch_size,
|
@classmethod
|
||||||
temperature, topk, start_image=None, clip_vision_output=None):
|
def execute(cls, positive, negative, vae, tracks, width, height, length, batch_size,
|
||||||
|
temperature, topk, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||||
|
|
||||||
tracks_data = parse_json_tracks(tracks)
|
tracks_data = parse_json_tracks(tracks)
|
||||||
|
|
||||||
if not tracks_data:
|
if not tracks_data:
|
||||||
return WanImageToVideo().encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
|
return WanImageToVideo().execute(positive, negative, vae, width, height, length, batch_size, start_image=start_image, clip_vision_output=clip_vision_output)
|
||||||
|
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||||
device=comfy.model_management.intermediate_device())
|
device=comfy.model_management.intermediate_device())
|
||||||
@ -699,34 +797,36 @@ class WanTrackToVideo:
|
|||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
class Wan22ImageToVideoLatent:
|
class Wan22ImageToVideoLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"vae": ("VAE", ),
|
return io.Schema(
|
||||||
"width": ("INT", {"default": 1280, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
node_id="Wan22ImageToVideoLatent",
|
||||||
"height": ("INT", {"default": 704, "min": 32, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
category="conditioning/inpaint",
|
||||||
"length": ("INT", {"default": 49, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
inputs=[
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
},
|
io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||||
"optional": {"start_image": ("IMAGE", ),
|
io.Int.Input("height", default=704, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||||
}}
|
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
RETURN_TYPES = ("LATENT",)
|
def execute(cls, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/inpaint"
|
|
||||||
|
|
||||||
def encode(self, vae, width, height, length, batch_size, start_image=None):
|
|
||||||
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([1, 48, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
if start_image is None:
|
if start_image is None:
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (out_latent,)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
@ -741,18 +841,25 @@ class Wan22ImageToVideoLatent:
|
|||||||
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
|
||||||
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
|
||||||
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
|
||||||
return (out_latent,)
|
return io.NodeOutput(out_latent)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class WanExtension(ComfyExtension):
|
||||||
"WanTrackToVideo": WanTrackToVideo,
|
@override
|
||||||
"WanImageToVideo": WanImageToVideo,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"WanFunControlToVideo": WanFunControlToVideo,
|
return [
|
||||||
"WanFunInpaintToVideo": WanFunInpaintToVideo,
|
WanTrackToVideo,
|
||||||
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
|
WanImageToVideo,
|
||||||
"WanVaceToVideo": WanVaceToVideo,
|
WanFunControlToVideo,
|
||||||
"TrimVideoLatent": TrimVideoLatent,
|
Wan22FunControlToVideo,
|
||||||
"WanCameraImageToVideo": WanCameraImageToVideo,
|
WanFunInpaintToVideo,
|
||||||
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
|
WanFirstLastFrameToVideo,
|
||||||
"Wan22ImageToVideoLatent": Wan22ImageToVideoLatent,
|
WanVaceToVideo,
|
||||||
}
|
TrimVideoLatent,
|
||||||
|
WanCameraImageToVideo,
|
||||||
|
WanPhantomSubjectToVideo,
|
||||||
|
Wan22ImageToVideoLatent,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> WanExtension:
|
||||||
|
return WanExtension()
|
||||||
|
|||||||
89
comfy_extras/nodes_context_windows.py
Normal file
89
comfy_extras/nodes_context_windows.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import comfy.context_windows
|
||||||
|
import nodes
|
||||||
|
|
||||||
|
|
||||||
|
class ContextWindowsManualNode(io.ComfyNode):
    """Node that attaches a manually configured context-window handler to a model.

    The handler is stored under ``model_options["context_handler"]`` on a clone
    of the incoming model, and the clone is returned as the node's only output.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Return the schema (id, display name, inputs, outputs) for this node."""
        return io.Schema(
            node_id="ContextWindowsManual",
            display_name="Context Windows (Manual)",
            category="context",
            description="Manually set context windows.",
            inputs=[
                io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
                io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."),
                io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."),
                io.Combo.Input("context_schedule", options=[
                    comfy.context_windows.ContextSchedules.STATIC_STANDARD,
                    comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
                    comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
                    comfy.context_windows.ContextSchedules.BATCHED,
                # The tooltip previously duplicated the "context_stride" text.
                ], tooltip="The schedule of the context window."),
                io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
                io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
                io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
                io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with context windows applied during sampling."),
            ],
            is_experimental=True,
        )

    @classmethod
    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.NodeOutput:
        """Install an ``IndexListContextHandler`` on a clone of ``model``.

        Args:
            model: The model to configure; it is cloned before being modified.
            context_length: Forwarded to the handler as ``context_length``.
            context_overlap: Forwarded to the handler as ``context_overlap``.
            context_schedule: Schedule name, resolved through
                ``comfy.context_windows.get_matching_context_schedule``.
            context_stride: Forwarded to the handler as ``context_stride``.
            closed_loop: Forwarded to the handler as ``closed_loop``.
            fuse_method: Fuse method name, resolved through
                ``comfy.context_windows.get_matching_fuse_method``.
            dim: Forwarded to the handler as ``dim``.

        Returns:
            An ``io.NodeOutput`` wrapping the configured model clone.
        """
        model = model.clone()
        model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
            context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
            fuse_method=comfy.context_windows.get_matching_fuse_method(fuse_method),
            context_length=context_length,
            context_overlap=context_overlap,
            context_stride=context_stride,
            closed_loop=closed_loop,
            dim=dim)
        # make memory usage calculation only take into account the context window latents
        comfy.context_windows.create_prepare_sampling_wrapper(model)
        return io.NodeOutput(model)
|
||||||
|
|
||||||
|
class WanContextWindowsManualNode(ContextWindowsManualNode):
    """Variant of ``ContextWindowsManualNode`` that always uses ``dim=2``.

    The schema is the parent schema with a new id, display name, description
    and input list; the input list has no ``dim`` entry because ``execute``
    passes ``dim=2`` itself.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Build the node schema by overriding fields of the parent schema."""
        schema = super().define_schema()
        schema.node_id = "WanContextWindowsManual"
        schema.display_name = "WAN Context Windows (Manual)"
        schema.description = "Manually set context windows for WAN-like models (dim=2)."
        schema.inputs = [
            io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
            io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."),
            io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."),
            io.Combo.Input("context_schedule", options=[
                comfy.context_windows.ContextSchedules.STATIC_STANDARD,
                comfy.context_windows.ContextSchedules.UNIFORM_STANDARD,
                comfy.context_windows.ContextSchedules.UNIFORM_LOOPED,
                comfy.context_windows.ContextSchedules.BATCHED,
            ], tooltip="The schedule of the context window."),
            io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
            io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
            io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
        ]
        return schema

    @classmethod
    def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.NodeOutput:
        """Rescale length/overlap by 4 and delegate to the parent with ``dim=2``.

        Both values are mapped with ``((v - 1) // 4) + 1`` (so 81 -> 21 and
        0 -> 0) and then clamped to a minimum of 1 and 0 respectively.
        NOTE(review): presumably this converts frame counts into latent-frame
        counts for a 4x temporally compressed latent - confirm.
        """
        context_length = max(((context_length - 1) // 4) + 1, 1)  # at least length 1
        context_overlap = max(((context_overlap - 1) // 4) + 1, 0)  # at least overlap 0
        return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextWindowsExtension(ComfyExtension):
    """Extension object exposing the context-window nodes of this module."""

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        return [
            ContextWindowsManualNode,
            WanContextWindowsManualNode,
        ]
|
||||||
|
|
||||||
|
def comfy_entrypoint() -> ContextWindowsExtension:
    """Create and return this module's ``ContextWindowsExtension`` instance."""
    return ContextWindowsExtension()
|
||||||
161
comfy_extras/nodes_model_patch.py
Normal file
161
comfy_extras/nodes_model_patch.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
import torch
|
||||||
|
import folder_paths
|
||||||
|
import comfy.utils
|
||||||
|
import comfy.ops
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.latent_formats
|
||||||
|
|
||||||
|
|
||||||
|
class BlockWiseControlBlock(torch.nn.Module):
|
||||||
|
# [linear, gelu, linear]
|
||||||
|
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
|
||||||
|
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
|
||||||
|
self.input_proj = operations.Linear(dim, dim)
|
||||||
|
self.act = torch.nn.GELU()
|
||||||
|
self.output_proj = operations.Linear(dim, dim)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
x, y = self.x_rms(x), self.y_rms(y)
|
||||||
|
x = self.input_proj(x + y)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.output_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageBlockWiseControlNet(torch.nn.Module):
    """Stack of ``BlockWiseControlBlock`` layers plus an input projection.

    Args:
        num_layers: Number of control blocks to create.
        in_dim: Base width of the patchified conditioning input.
        additional_in_dim: Extra input width added on top of ``in_dim``.
        dim: Hidden width of the projection and of every control block.
        device: Device forwarded to the sub-layers.
        dtype: Dtype forwarded to the sub-layers.
        operations: Namespace providing the layer classes.
    """

    def __init__(
        self,
        num_layers: int = 60,
        in_dim: int = 64,
        additional_in_dim: int = 0,
        dim: int = 3072,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        self.additional_in_dim = additional_in_dim
        self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
        self.controlnet_blocks = torch.nn.ModuleList(
            [
                BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
                for _ in range(num_layers)
            ]
        )

    def process_input_latent_image(self, latent_image):
        """Patchify a conditioning latent and project it with ``img_in``.

        NOTE: the first 16 channels of ``latent_image`` are overwritten in
        place with their ``Wan21().process_in`` result, so the caller's tensor
        is modified.
        NOTE(review): the ``view`` below only succeeds when the padded tensor
        has exactly ``B * C * H * W`` elements, i.e. any dimension between the
        channel and the last two dimensions has size 1 - confirm with callers.
        """
        latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16])
        patch_size = 2
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
        orig_shape = hidden_states.shape
        patches_h = orig_shape[-2] // patch_size
        patches_w = orig_shape[-1] // patch_size
        # (B, C, H, W) -> (B, C, H/p, p, W/p, p) -> (B, H/p, W/p, C, p, p)
        hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], patches_h, patch_size, patches_w, patch_size)
        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
        # One token per patch, each holding C * p * p features.
        hidden_states = hidden_states.reshape(orig_shape[0], patches_h * patches_w, orig_shape[1] * patch_size * patch_size)
        return self.img_in(hidden_states)

    def control_block(self, img, controlnet_conditioning, block_id):
        """Run control block number ``block_id`` on ``img`` and the conditioning."""
        return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPatchLoader:
    """Node that loads a file from the ``model_patches`` folder as a MODEL_PATCH."""

    @classmethod
    def INPUT_TYPES(s):
        """Offer every file name found in the ``model_patches`` folder."""
        return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
                              }}

    RETURN_TYPES = ("MODEL_PATCH",)
    FUNCTION = "load_model_patch"
    EXPERIMENTAL = True

    CATEGORY = "advanced/loaders"

    def load_model_patch(self, name):
        """Load ``name`` and return it wrapped in a ``ModelPatcher``.

        Returns:
            A 1-tuple holding the ``ModelPatcher`` around a
            ``QwenImageBlockWiseControlNet`` populated from the state dict.

        Raises:
            Whatever ``folder_paths.get_full_path_or_raise`` raises when the
            file cannot be resolved, ``KeyError`` when the state dict has no
            ``img_in.weight`` entry, and the ``load_state_dict`` error when
            the keys or shapes do not match the network.
        """
        model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
        sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
        dtype = comfy.utils.weight_dtype(sd)
        # TODO: this node will work with more types of model patches
        # 64 matches the default ``in_dim`` of QwenImageBlockWiseControlNet;
        # any input width beyond that is treated as additional channels.
        base_in_dim = 64
        additional_in_dim = sd["img_in.weight"].shape[1] - base_in_dim
        model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
        model.load_state_dict(sd)
        model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
        return (model,)
|
||||||
|
|
||||||
|
|
||||||
|
class DiffSynthCnetPatch:
    """Callable block patch that adds a control signal to ``kwargs["img"]``.

    The control image is VAE-encoded and patchified once in ``__init__`` and
    cached in ``encoded_image``; ``__call__`` re-encodes it when the cached
    shape no longer matches the incoming ``img``.
    """

    def __init__(self, model_patch, vae, image, strength, mask=None):
        self.model_patch = model_patch
        self.vae = vae
        self.image = image
        self.strength = strength
        self.mask = mask
        self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))

    def encode_latent_cond(self, image):
        """VAE-encode ``image`` and, if the patch expects it, append mask channels.

        When the patch model has ``additional_in_dim > 0`` the result is the
        latent concatenated on dim 1 with either an all-ones tensor (no mask
        given) or the mask resized to the latent's last two dimensions.
        """
        latent_image = self.vae.encode(image)
        if self.model_patch.model.additional_in_dim > 0:
            if self.mask is None:
                mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4]
            else:
                mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")

            return torch.cat([latent_image, mask_], dim=1)
        else:
            return latent_image

    def __call__(self, kwargs):
        """Add the scaled control-block output to ``kwargs["img"]`` and return ``kwargs``."""
        x = kwargs.get("x")
        img = kwargs.get("img")
        block_index = kwargs.get("block_index")
        if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
            # Cached conditioning does not fit this call: rescale the control
            # image to the size implied by ``x`` and encode it again.
            spacial_compression = self.vae.spacial_compression_encode()
            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
            # Remember the models in use and load them again after the encode.
            # NOTE(review): presumably the VAE encode can displace them - confirm.
            loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
            self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
            comfy.model_management.load_models_gpu(loaded_models)

        img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
        kwargs['img'] = img
        return kwargs

    def to(self, device_or_dtype):
        """Move the cached conditioning to a device; dtypes are ignored.

        Returns ``self`` so the call can be chained. Nothing is moved when no
        conditioning is cached, matching the ``None`` check in ``__call__``.
        """
        if isinstance(device_or_dtype, torch.device) and self.encoded_image is not None:
            self.encoded_image = self.encoded_image.to(device_or_dtype)
        return self

    def models(self):
        """Return the model patch this object depends on, as a list."""
        return [self.model_patch]
|
||||||
|
|
||||||
|
class QwenImageDiffsynthControlnet:
    """Node that installs a ``DiffSynthCnetPatch`` on a clone of the model."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the required inputs and the optional ``mask`` input."""
        return {"required": { "model": ("MODEL",),
                              "model_patch": ("MODEL_PATCH",),
                              "vae": ("VAE",),
                              "image": ("IMAGE",),
                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                              },
                "optional": {"mask": ("MASK",)}}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "diffsynth_controlnet"
    EXPERIMENTAL = True

    CATEGORY = "advanced/loaders/qwen"

    def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
        """Return a 1-tuple with a clone of ``model`` carrying the control patch.

        Only the first three entries of the image's last dimension are kept.
        A given mask is expanded and inverted before it is handed to the patch.
        """
        model_patched = model.clone()
        image = image[:, :, :, :3]
        if mask is not None:
            # The two checks run in sequence, so a 3-D mask gains two
            # singleton dimensions and a 4-D mask gains one.
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)
            if mask.ndim == 4:
                mask = mask.unsqueeze(2)
            mask = 1.0 - mask

        model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
        return (model_patched,)
|
||||||
|
|
||||||
|
|
||||||
|
# Maps node identifiers to the node classes defined in this module.
NODE_CLASS_MAPPINGS = {
    "ModelPatchLoader": ModelPatchLoader,
    "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
}
|
||||||
48
comfy_extras/nodes_qwen.py
Normal file
48
comfy_extras/nodes_qwen.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import node_helpers
|
||||||
|
import comfy.utils
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncodeQwenImageEdit:
    """Node that encodes a prompt, optionally together with a reference image."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare ``clip``/``prompt`` as required and ``vae``/``image`` as optional."""
        return {"required": {
            "clip": ("CLIP", ),
            "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            },
            "optional": {"vae": ("VAE", ),
                         "image": ("IMAGE", ),}}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "advanced/conditioning"

    def encode(self, clip, prompt, vae=None, image=None):
        """Tokenize and encode ``prompt``; attach a reference latent when possible.

        With an ``image`` the picture is rescaled, keeping its aspect ratio,
        to roughly 1024 * 1024 pixels and its first three channels are passed
        to the tokenizer. If a ``vae`` is given as well, the rescaled image is
        encoded and appended to the conditioning under ``reference_latents``.

        Returns:
            A 1-tuple holding the conditioning.
        """
        ref_latent = None
        if image is None:
            images = []
        else:
            samples = image.movedim(-1, 1)
            total = int(1024 * 1024)

            # Uniform scale factor that brings width * height to ``total``.
            scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
            width = round(samples.shape[3] * scale_by)
            height = round(samples.shape[2] * scale_by)

            s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
            image = s.movedim(1, -1)
            images = [image[:, :, :, :3]]
            if vae is not None:
                ref_latent = vae.encode(image[:, :, :, :3])

        tokens = clip.tokenize(prompt, images=images)
        conditioning = clip.encode_from_tokens_scheduled(tokens)
        if ref_latent is not None:
            conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
        return (conditioning, )
|
||||||
|
|
||||||
|
|
||||||
|
# Maps node identifiers to the node classes defined in this module.
NODE_CLASS_MAPPINGS = {
    "TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
}
|
||||||
0
models/model_patches/put_model_patches_here
Normal file
0
models/model_patches/put_model_patches_here
Normal file
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui"
|
name = "comfyui"
|
||||||
version = "0.3.49"
|
version = "0.3.51"
|
||||||
description = "An installable version of ComfyUI"
|
description = "An installable version of ComfyUI"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
authors = [
|
authors = [
|
||||||
@ -18,9 +18,9 @@ classifiers = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"comfyui-frontend-package>=1.24.4",
|
"comfyui-frontend-package>=1.25.9",
|
||||||
"comfyui-workflow-templates>=0.1.51",
|
"comfyui-workflow-templates>=0.1.62",
|
||||||
"comfyui-embedded-docs>=0.2.4",
|
"comfyui-embedded-docs>=0.2.6",
|
||||||
"torch",
|
"torch",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"torchdiffeq>=0.2.3",
|
"torchdiffeq>=0.2.3",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user